TextRNN的PyTorch实现

本文介绍一下如何使用PyTorch复现TextRNN,实现预测一句话的下一个词

参考这篇论文Finding Structure in Time(1990),如果你对RNN有一定的了解,实际上不用看,仔细看我代码如何实现即可。如果你对RNN不太了解,请仔细阅读我这篇文章RNN Layer,结合PyTorch讲的很详细

现在问题的背景是,我有n句话,每句话都由且仅由3个单词组成。我要做的是,将每句话的前两个单词作为输入,最后一词作为输出,训练一个RNN模型

导库

'''
  code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor
'''
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

dtype = torch.FloatTensor

准备数据

sentences = [ "i like dog", "i love coffee", "i hate milk"]

word_list = " ".join(sentences).split()
vocab = list(set(word_list))
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for i, w in enumerate(vocab)}
n_class = len(vocab)

预处理数据,构建Dataset,定义DataLoader,输入数据用one-hot编码

# TextRNN Parameter
batch_size = 2
n_step = 2 # number of cells(= number of Step)
n_hidden = 5 # number of hidden units in one cell

def make_data(sentences):
    input_batch = []
    target_batch = []

    for sen in sentences:
        word = sen.split()
        input = [word2idx[n] for n in word[:-1]]
        target = word2idx[word[-1]]

        input_batch.append(np.eye(n_class)[input])
        target_batch.append(target)

    return input_batch, target_batch

input_batch, target_batch = make_data(sentences)
input_batch, target_batch = torch.Tensor(input_batch), torch.LongTensor(target_batch)
dataset = Data.TensorDataset(input_batch, target_batch)
loader = Data.DataLoader(dataset, batch_size, True)

以上的代码我想大家应该都没有问题,接下来就是定义网络架构

class TextRNN(nn.Module):
    def __init__(self):
        super(TextRNN, self).__init__()
        self.rnn = nn.RNN(input_size=n_class, hidden_size=n_hidden)
        # fc
        self.fc = nn.Linear(n_hidden, n_class)

    def forward(self, hidden, X):
        # X: [batch_size, n_step, n_class]
        X = X.transpose(0, 1) # X : [n_step, batch_size, n_class]
        out, hidden = self.rnn(X, hidden)
        # out : [n_step, batch_size, num_directions(=1) * n_hidden]
        # hidden : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
        out = out[-1] # [batch_size, num_directions(=1) * n_hidden] ⭐
        model = self.fc(out)
        return model

model = TextRNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

以上代码每一步都值得说一下,首先是nn.RNN(input_size, hidden_size)的两个参数,input_size表示每个词的编码维度,由于我是用的one-hot编码,而不是WordEmbedding,所以input_size就等于词库的大小len(vocab),即n_class。然后是hidden_size,这个参数没有固定的要求,你想将输入数据的维度转为多少维,就设定多少

对于通常的神经网络来说,输入数据的第一个维度一般都是batch_size。而PyTorch中nn.RNN()要求将batch_size放在第二个维度上,所以需要使用x.transpose(0, 1)将输入数据的第一个维度和第二个维度互换

然后是rnn的输出,rnn会返回两个结果,即上面代码的out和hidden,关于这两个变量的区别,我在之前的博客也提到过了,如果不清楚,可以看我上面提到的RNN Layer这篇博客。这里简单说就是,out指的是下图的红框框起来的所有值;hidden指的是下图蓝框框起来的所有值。我们需要的是最后时刻的最后一层输出,即Y3Y_3的值,所以使用out=out[-1]将其获取

剩下的部分就比较简单了,训练测试即可

# Training
for epoch in range(5000):
    for x, y in loader:
      # hidden : [num_layers * num_directions, batch, hidden_size]
      hidden = torch.zeros(1, x.shape[0], n_hidden)
      # x : [batch_size, n_step, n_class]
      pred = model(hidden, x)

      # pred : [batch_size, n_class], y : [batch_size] (LongTensor, not one-hot)
      loss = criterion(pred, y)
      if (epoch + 1) % 1000 == 0:
          print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
    
input = [sen.split()[:2] for sen in sentences]
# Predict
hidden = torch.zeros(1, len(input), n_hidden)
predict = model(hidden, input_batch).data.max(1, keepdim=True)[1]
print([sen.split()[:2] for sen in sentences], '->', [idx2word[n.item()] for n in predict.squeeze()])

完整代码如下

'''
  code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor
'''
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

dtype = torch.FloatTensor

sentences = [ "i like dog", "i love coffee", "i hate milk"]

word_list = " ".join(sentences).split()
vocab = list(set(word_list))
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for i, w in enumerate(vocab)}
n_class = len(vocab)

# TextRNN Parameter
batch_size = 2
n_step = 2 # number of cells(= number of Step)
n_hidden = 5 # number of hidden units in one cell

def make_data(sentences):
    input_batch = []
    target_batch = []

    for sen in sentences:
        word = sen.split()
        input = [word2idx[n] for n in word[:-1]]
        target = word2idx[word[-1]]

        input_batch.append(np.eye(n_class)[input])
        target_batch.append(target)

    return input_batch, target_batch

input_batch, target_batch = make_data(sentences)
input_batch, target_batch = torch.Tensor(input_batch), torch.LongTensor(target_batch)
dataset = Data.TensorDataset(input_batch, target_batch)
loader = Data.DataLoader(dataset, batch_size, True)

class TextRNN(nn.Module):
    def __init__(self):
        super(TextRNN, self).__init__()
        self.rnn = nn.RNN(input_size=n_class, hidden_size=n_hidden)
        # fc
        self.fc = nn.Linear(n_hidden, n_class)

    def forward(self, hidden, X):
        # X: [batch_size, n_step, n_class]
        X = X.transpose(0, 1) # X : [n_step, batch_size, n_class]
        out, hidden = self.rnn(X, hidden)
        # out : [n_step, batch_size, num_directions(=1) * n_hidden]
        # hidden : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
        out = out[-1] # [batch_size, num_directions(=1) * n_hidden] ⭐
        model = self.fc(out)
        return model

model = TextRNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training
for epoch in range(5000):
    for x, y in loader:
      # hidden : [num_layers * num_directions, batch, hidden_size]
      hidden = torch.zeros(1, x.shape[0], n_hidden)
      # x : [batch_size, n_step, n_class]
      pred = model(hidden, x)

      # pred : [batch_size, n_class], y : [batch_size] (LongTensor, not one-hot)
      loss = criterion(pred, y)
      if (epoch + 1) % 1000 == 0:
          print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
  
input = [sen.split()[:2] for sen in sentences]
# Predict
hidden = torch.zeros(1, len(input), n_hidden)
predict = model(hidden, input_batch).data.max(1, keepdim=True)[1]
print([sen.split()[:2] for sen in sentences], '->', [idx2word[n.item()] for n in predict.squeeze()])

热门文章

暂无图片
编程学习 ·

POJ练习题之:败方树

问题描述 给定一个整数数组,要求对数组中的元素构建败方树(数组相邻元素两两比较,从第一个元素开始)。之后修改数组中的元素,要求输出初始构建以及修改后得到的败方树的所有内部结点代表的整数(从左到右从上到下输出) 输入 第一行为数组的元素个数n和修改的次数m。 第二行…
暂无图片
编程学习 ·

Paddle_程序员必备的数学知识_转发

程序员——必备数学知识!!!Attention 本博客转发至百度aistudio的<深度学习7日入门-cv疫情检测>,课程非常棒!本人力推! 博客转发地址:https://aistudio.baidu.com/aistudio/projectdetail/604807 课程报名地址:https://aistudio.baidu.com/aistudio/education/group/i…
暂无图片
编程学习 ·

SSM整合小案例

SSM整合 数据库部分(Oracle)创建表 CREATE TABLE product( id varchar2(32) default SYS_GUID() PRIMARY KEY, productNum VARCHAR2(50) NOT NULL, productName VARCHAR2(50), cityName VARCHAR2(50), DepartureTime timestamp, productPrice Number, productDesc VARCHAR2(500…
暂无图片
编程学习 ·

算法复杂度评价指标(大o表示法)

大O表示法(1)常见的大o数量级函数(2)其他算法复杂度表示法 基本操作数量函数T(n)的精确值并不是特别重要,重要的是Tn(n)中起决定性因素的主导部分。用动态的眼光看,就是当问题规模增大的时候,T(n)中的一些部分会盖过其他部分的贡献。 数量级函数描述了T(n)中随着n增加而…
暂无图片
编程学习 ·

CSS滚动指示器

一、CSS滚动指示器 滚动指示器指的是页面的顶端会有一个进度条,指示滚动的进度。效果如下GIF所示(点击播放):CSS滚动指示器指的是不借助JavaScript,纯CSS实现滚动进度效果。 二、传统的实现方法 传统CSS实现方法由一个名叫 Mike的人首先提出,时间应该是16年,这个CodePen…
暂无图片
编程学习 ·

分布式数据存储系统之三要素

什么是分布式数据存储系统? 分布式存储系统的核心逻辑,就是将用户需要存储的数据根据某种规则存储到不同的机器上,当用户想要获取指定数据时,再按照规则到存储数据的机器里获取。 如下图所示,当用户(即应用程序)想要访问数据 D,分布式操作引擎通过一些映射方式,比如 H…
暂无图片
编程学习 ·

DAY14 Javaweb Servlet、Response、Request

以下讲的都是最底层的内容,以后会被新的方法顶替掉一、Servlet,是sun公司开发的一门技术,如果要开发sevlet程序(网页java),只需要1、实现这个接口就可以 2、把开发好的java类部署到web服务器中。把实现了Servlet接口的Java程序叫做Servlet,一个请求地址对应一个servlet…
暂无图片
编程学习 ·

Android编程权威指南总结(六)

第十七章 双版面主从用户界面本章是为了适应平板设备。双版面主从用户界面,也就是平板上的列表和详情界面同时展示的情况。一、增加布局灵活性双版面布局里面,一个 Activity 托管两个 Fragment。1、方法上使用 @LayoutRes 注解,这告诉Android Studio,任何时候该注解的…
暂无图片
编程学习 ·

Spring Boot 集成 WebSocket 实现服务端推送消息到客户端

假设有这样一个场景:服务端的资源经常在更新,客户端需要尽量及时地了解到这些更新发生后展示给用户,如果是 HTTP 1.1,通常会开启 ajax 请求询问服务端是否有更新,通过定时器反复轮询服务端响应的资源是否有更新。ajax 轮询在长时间不更新的情况下,反复地去询问会对服务器…
暂无图片
编程学习 ·

Explicit Model Predictive Control of a Magnetic Flexible Endoscope

对胶囊的动力学进行建模,能更好的对胶囊进行控制,在已知胶囊预定义轨迹的情况下,对胶囊进行预测控制和定位。 一个磁灵活内窥镜的显式模型预测控制 Explicit Model Predictive Control of a Magnetic Flexible Endoscope [1] Paper Link Authors: Scaglioni, Bruno, et al. …
暂无图片
编程学习 ·

Spring Boot + RabbitMQ 配置参数解释

application.properties配置文件写法#rabbitmq spring.rabbitmq.virtual-host=/ spring.rabbitmq.host=192.168.124.20 spring.rabbitmq.port=5672 spring.rabbitmq.username=guest spring.rabbitmq.password=guest spring.rabbitmq.listener.concurrency=10 spring.rabbitmq.l…
暂无图片
编程学习 ·

【Android开发--新手必看篇】Calendar类的使用

Android笔记 ​ ——其他 【若对该知识点有更多想了解的,欢迎私信博主~~】 Calendar类: 获取日期 注:在JDK1.0中,Date类是唯一处理时间的类,但是由于Date类中方法比较少并且有一些方法不便于实现国际化,所以从JDK1.1版本开始新增了Calendar类,增加了许多功能强大的方法…
暂无图片
编程学习 ·

Java 常用算法

算法一:分治法 基本概念 1.把一个复杂的问题分成两个或更多的相同或相似的子问题,再把子问题分成更小的子问题……直到最后子问题可以简单的直接求解,原问题的解即子问题的解的合并。 2.分治策略是对于一个规模为n的问题,若该问题可以容易地解决(比如说规模n较小)则直接解…
暂无图片
编程学习 ·

阿里云云计算ACA练习题5

如果将跨阿里云账号的ECS自建数据库迁移至云数据库RDS,以下哪一项不是必须配置的信息? A. 目标RDS实例IDB. RDS实例访问账号C. RDS实例IP地址D. RDS实例访问账号对应的密码【参考答案】:B云计算面临的主要安全威胁,按照对系统的影响维度可以分为?(正确答案的数量:3) (多选…
暂无图片
编程学习 ·

JMXTrans入门教程

概述 官网 GitHub JMX JMX,即Java Management Extensions,监控Java应用程序系统运行的状态信息,通过分析JMX信息,可用于监控应用程序运行状态、优化程序、排查问题。 JMXTrans JMXTrans是一款开源的JMX指标采集工具,使用简单方便,无需编写代码,只需要配置文件就可以轻松…
暂无图片
编程学习 ·

Go 结构体使用注意事项和细节

结构体使用注意事项和细节结构体的所有字段在内存中是连续的//结构体 type Point struct {x inty int }//结构体 type Rect struct {leftUp, rightDown Point }func main() {r1 := Rect{Point{1,2}, Point{3,4}}//r1有四个int, 在内存中是连续分布//打印地址fmt.Printf("…
暂无图片
编程学习 ·

Java数据结构--数组、矩阵、广义表

一、简介 1.1 数组的概念是n(n ≥ 1)个相同数据类型的数据元素a0,a1,…,an-1构成的占用一块联系地址的内存单元的有限集合。1.2 特点(1)数组中数据元素的数据类型相同; (2)数组是一种随机存取结构,只要给定一组下标,就可以访问与其对应的数组元素; (3)数组中数据元素…