Transformer模型在文献[1]中提出,它不再使用RNN来处理序列问题,而是完全采用注意力机制通过编码、解码和上下文向量等技术,克服了RNN难以并行处理的缺陷,取得了很好的效果,是目前应用最为广泛的处理序列问题的方法。
1 总体结构与应用示例
图15-1 Transformer总体结构[1]
模型的总体结构如图15-1所示,为了便于读者理解,先将模型看成一个黑盒子,以英译中应用示例来说明模型的总体工作过程。
Transformer有2个输入口和1个输出口。2个输入口如图 64中最下面的Inputs和Outputs所示,它们分别接收训练样本的实例(英文序列)和标签(中文序列)。1个输出口如图15-1中最上面的Output Probabilities所示,它输出当前预测。
用英译中任务来概要说明Transformer的工作过程,包括训练过程和预测过程。
1)训练过程
英译中任务的训练样本的英文序列从Inputs输入口输入,中文序列从Outputs输入口输入。
一个样本的英文序列是一次性整句输入的。对应的中文序列不是一次性整句输入模型的,而是分多次输入模型进行训练的。第一次输入Outputs输入口的是一个标志句子起始的特殊符号,一般记为。以后每次输入的都是依序从句子取一个元素附加到后面形成的序列。例如,对于句子“我爱大模型”,每次输入分别为:“”、“我”、“我爱”、…、“我爱大模型”。对应每次输入的Output Probabilities模型输出,是对下一个元素的预测,例如,对于输入“我爱”,此时模型的输出是对“大”字的预测值。该预测值与“大”字的实际表示值的误差用于反向传播学习模型的参数。对于整句“我爱大模型”的输入,模型输出的是对标志句子结束的特殊符号的预测。
2)预测过程
预测时,先将英文序列完整输入Inputs输入口。然后将输入Outputs输入口,从Output Probabilities模型输出口中得到一个预测的中文元素。该元素附在之后一起从Outputs输入,再次预测得到一个中文元素。如此这般重复运行,直到预测输出,结束对英文句子的中文翻译。句子“我爱大模型”的正确预测过程如所图15-2示意。
图15-2 Transformer预测过程示意
PyTorch在不同粒度上提供了对Transformer的支持。为了便于读者理解,本节将从整体到细节,层层深入地讨论Transformer模型。
torch.nn.Transformer类实现了图15-1中的虚线框内部分,它封装了实现的细节,便于应用。下面先用它为主来实现一个示例以说明Transformer的应用和总体工作过程,如代码15-1所示,对该示例的说明分为准备数据、构建模型、训练模型和推理四部分进行。
1)准备数据
第一部分是完成数据准备工作,先从文件中读取训练语料,然后建立对每个语言元素编号的词表,最后对语料中的中英文句子添加开始标记和结束标记,并用编号表示。要注意的是,为了突出主题,本示例采用最简单的语言元素,即英文中的字母和中文中的字。同时,对模型的训练也没有采用批梯度下降法,只采用了逐条训练的方法。
2)构建模型
第二部分首先定义了TransformerTranslator类,构建了示例用的Transformer模型。其中定义了两个词向量层,分别对应图15-1中的Input Embedding和Output Embedding。
接下来定义了图 15-1中虚框部分,由实例化torch.nn.Transformer类得到,主要参数含义为:d_model表示模型内部向量的维数;num_encoder_layers表示编码器的层数(下文详述);num_decoder_layers=num_layers表示解码器的层数(下文详述);dim_feedforward表示编码器和解码器内部前馈神经网络的隐藏层维度(下文详述)。
随后定义了一个线性层,对应图15-1上部的Linear。
随后定义了一个对位置进行编码的方法,其输出将叠加到词向量上,如图15-1中Positional Encoding所示。下面展开讨论该位置编码的含义及原理。
由 Attention 的讨论可知,Attention 的处理过程不能体现序列数据的位置关系,而 Transformer 模型也不再使用善于处理序列数据的循环神经网络,因此,需要引入一种能够感知和处理元素在序列中的位置的信息的方法。通常将这类方法称为位置编码(Position Encoding)或者位置嵌入(Position embedding)。Transformer 的原论文 [1] 采用了一种直接计算元素在序列中绝对位置的方法。设元素在序列中的位置序号为pospospos,元素的位置编码用长度为ddd的向量来表示,该位置向量的计算方法为:
PE(pos,2i)=sin(pos100002i/d)PE(pos,2i+1)=cos(pos100002i/d)(15-1) PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right) \\ PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right) \tag{15-1}PE(pos,2i)=sin(100002i/dpos)PE(pos,2i+1)=cos(100002i/dpos)(15-1)
其中,iii取值0∼d20 \sim \frac{d}{2}0∼2d,PE(pos,2i)PE_{(pos,2i)}PE(pos,2i)表示序号为偶数的向量元素,PE(pos,2i+1)PE_{(pos,2i+1)}PE(pos,2i+1)表示序号为奇数的向量元素。
如果用 3 维的向量来对字在句子中的位置进行编码,则句子中的第 1 个字(例如,句子“机器学习是非常重要的研究和应用领域”中的“机”字)的位置向量为:[PE(1,0),PE(1,1),PE(1,2)]=[sin(1100000/3),cos(1100001/3),sin(1100002/3)]≈[0.0,1.0,0.0][PE_{(1,0)}, PE_{(1,1)}, PE_{(1,2)}] = \left[ \sin\left(\frac{1}{10000^{0/3}}\right), \cos\left(\frac{1}{10000^{1/3}}\right), \sin\left(\frac{1}{10000^{2/3}}\right) \right] \approx [0.0, 1.0, 0.0][PE(1,0),PE(1,1),PE(1,2)]=[sin(100000/31),cos(100001/31),sin(100002/31)]≈[0.0,1.0,0.0],第 2 个字和第 3 个字的位置向量分别为[0.84,0.54,0.0][0.84, 0.54, 0.0][0.84,0.54,0.0]和[0.91,−0.42,0.0][0.91, -0.42, 0.0][0.91,−0.42,0.0]。
这种直接计算位置向量的方法利用了sin\sinsin和cos\coscos函数在不同时刻的非线性变化特点来产生可以让模型识别的位置信息。计算得到的位置向量通常是与词向量相加后输入编码器。
位置编码方法的具体实现可参考代码及其注释。
随后定义了生成序列掩码的方法,该方法生成一个上三角为负无穷、其他元素为0的矩阵。该矩阵的一行与中文整句一起用于计算注意力,从而实现了将中文句子逐字扩展输入模型。
随后定义了代表前向传播的forward方法,它的输入是英文句子src和中文句子tgt,输出是对下一个中文的预测值output。src经过词嵌入、叠加位置编码后输入self.transformer。tgt经过类似的过程后也输入self.transformer。同时输入self.transformer的还有中文句子的掩码tgt_mask。self.transformer的输出经过线性层后得到模型的输出。
代码15-1 应用torch.nn.Transformer类实现英译中示例
importtorchimporttorch.nnasnnimporttorch.optimasoptimimportnumpyasnp# 设置随机种子torch.manual_seed(1121)# 1. 准备数据classTranslationDataset:def__init__(self,data_path,n_samples):# 从data_path读取文件,取n_samples条语料,构建词表self.raw_data=[]withopen(data_path,'r',encoding='utf-8')asfile:n=0forlineinfile:line=line.strip()# 去掉换行符等ifline:# 跳过空行sentences=line.split('\t')# Tab符号隔开的英文和中文self.raw_data.append(sentences)n+=1ifn>=n_samples:break# 构建词表,对基本语言元素进行编号self.src_vocab={"<pad>":0,"<bos>":1,"<eos>":2,"<unk>":3}# 特殊标记符号self.tgt_vocab={"<pad>":0,"<bos>":1,"<eos>":2,"<unk>":3}# 下面是把句子拆分成基本语言元素。要注意的是,为了突出主题,方便读者理解,# 同时降低演示代码运行要求,这里是把英文句子拆成字母了,并没有以单词为基本# 语言元素,同样汉语句子也拆成了单字,没有分词。src_words=set()# 用集合来存放所有基本语言元素tgt_words=set()foren,zhinself.raw_data:# 依次取每一条语料src_words.update(list(en.lower()))# list()函数分解所有字母和符号tgt_words.update(list(zh))forwordinsrc_words:self.src_vocab[word]=len(self.src_vocab)# 依次递增编号forword