本文作者为李炎洋,来自东北大学自然语言处理实验室,向雷锋网AI科技评论独家投稿。
对于目前基于神经网络的序列模型,很重要的一个任务就是从序列模型中采样。比如解码时我们希望能产生多个不一样的结果,而传统的解码算法只能产生相似的结果。又比如训练时使用基于强化学习或者最小风险训练的方法需要从模型中随机采集多个不一样的样本来计算句子级的损失,而一般的确定性方法不能提供所需要的随机性。本文回顾了一系列常用的序列模型采样方法,包括基于蒙特卡洛的随机采样和随机束搜索,以及最近提出的基于Gumbel-Top-K的随机束搜索。表1展示了这三种方法各自的优缺点。
在此之前,我们首先回顾一下束搜索。在序列模型中,束搜索通常被用来提升模型解码时的性能。默认的贪婪解码总是在每一步挑选一个当前分数最高的词来组成序列。相比起贪婪解码,束搜索每一步都挑选多个词来组成多个候选序列,最后挑选分数最高的序列作为最终输出。束搜索虽然增加了计算量,但是也显著提升了模型性能。图1是一个束大小为2的束搜索的例子:
图1 束搜索第一步
在解码第一步的时候,束搜索从句子开始符<START>开始,根据模型的打分logPLM(PLM是在给定前缀的情况下模型输出的下一词分布)来挑选词表中得分最高的前两个词he和I,并用he和I的得分logPLM(he|<START> )和logPLM(I|<START> )分别作为候选序列<START> he和<START> I的得分。
图2 计算束搜索第二步打分
在解码第二步的时候,根据模型的打分logPLM为已经生成部分内容的句子<START> he和<START> I各自挑选得分最高的前两个词,如<START> he会挑选hit和struck,<START> I会挑选was和got,然后组成一共四个候选序列<START> he hit,<START> he struck,<START> I was和<START> I got,并分别计算他们的得分,比如<START> he hit的得分等于<START> he这个序列的得分加上hit的得分 logPLM(hit|<START> he),如图2所示。最后保留这四个候选序列中得分最高的前两个序列
即<START> he hit和<START> I was,如图3所示。
图3 挑选束搜索第二步候选
以此类推,束搜索一直迭代到固定次数或者所有的候选序列都结束才停止。在这个例子中束搜索在第六步停止,产生了两个候选序列<START> he hit me with a pie和<START> he hit me with a tart,并挑选得分最高的<START> he hit me with a pie作为最终的结果,如图4所示。
图4 束搜索最终结果
从序列模型中采集多个样本有两种经典的方法:基于蒙特卡洛的随机采样和基于蒙特卡洛的束搜索。
在序列模型中采样的最简单方法就是在贪婪搜索的基础上,在每一步挑选下一个词的时候不是根据它们相应的得分logPLM而是根据模型输出的下一个词分布PLM来随机选取一个,这样重复到固定长度或者挑选到句子结束符时停止。这样我们获得了一个样本。如果需要采集多个样本,那么重复这个过程若干次便可得到多个样本。
基于蒙特卡洛的随机采样虽然简单,但是它面临着严重的效率问题。如果模型输出的下一个词分布PLM熵很低,即对于个别词输出概率特别高,那么采集到的样本将有很大一部分重复,比如接近收敛时候的模型。因此为了采集到固定数目的不同样本,基于蒙特卡洛的随机采样可能需要远远大于所需样本数的采样次数,使得采样过程十分低效。
基于蒙特卡洛的随机束搜索在采集多个不同样本远比基于蒙特卡洛的随机采样高效。假设现在束大小为K,基于蒙特卡洛的随机束搜索在束搜索的基础上,把根据下一词的得分logPLM挑选前K个得分最高的词的操作替换成根据下一个词分布PLM随机挑选K个不同词。因为每一步都挑选了不同的词,因此最终产生的K个候选序列都不会相同,从而达到了高效采集K个样本的目的。
但是基于蒙特卡洛的随机束搜索也面临着方差的问题。在每一步中它都是根据PLM随机挑选K个不同词,它无法控制随机采样时的噪声,也就是样本分布的方差跟每一步的PLM的方差相关,而PLM的方差是无法控制的,它可能非常大也可能非常小。因此在基于蒙特卡洛的随机束搜索采集到的样本上估计的统计量会非常不稳定,比如在使用句子级损失的任务中采用样本估计损失的时候会计算出不稳定的值,使模型训练受到影响。
解决基于蒙特卡洛的随机束搜索的问题关键在于怎么控制每一步随机采样时的噪声。最近的论文提出使用了Gumbel-Top-K技巧来达到这个目的。
如果我们把每个可能的句子当成一个单独的类别来构造一个类别数非常庞大(假设所有句子长度相等,那么有VT个类别,其中V是词表大小,T是句子长度)的类别分布,那么便可以使用Gumbel-Top-K技巧来从这一个庞大的类别分布中采集K个不同样本,同时每个样本都服从于原始的分布。这也是论文提出的自底向上的采样方法。
图5 自底向上的采样方法
图5展示了一个词表大小V=3(hello,world,!),句子长度T=3和样本数K=2的例子。我们需要先从第一个词开始枚举所有的9个可能的句子,同时使用模型计算这9个句子的概率。因为模型通常只能计算整个句子的概率,而Gumbel噪声需要加到整个logit上,我们可以使用整个句子的对数概率
我们就完成了采样,但是自顶向上的方法需要先枚举所有句子和计算其对数概率才能开始使用噪声扰动每个句子的对数概率,那么我们能不能从句子开始一边枚举一边计算和扰动生成的不同句子的对数概率?在此之前,我们必须先定义在枚举过程中中间生成的只有部分内容的句子的对数扰动概率。只有部分内容的句子(部分生成的句子)的对数扰动概率,比如例子中的<START> world,定义为以该部分生成的句子为前缀的所有完整句子中对数扰动概率最大的一个
这样,我们可以一边枚举所有句子的同时计算句子的对数扰动概率。
更进一步地,我们可以看到,因为我们定义部分生成的句子的对数扰动概率为其对应的所有完整句子的最大的对数扰动概率,因此如果我们在枚举的时候只保留分数最高的K个候选,那么我们可以保证最终的K个候选一定是所有句子中分数最高的前K个,因为部分生成的句子的对数扰动概率的定义已经说明一个内部节点的所有叶子节点的对数扰动概率不可能比它的对数扰动概率大,因此在当前一层中不是分数最高的前K个的话以后它任何一个后代节点也不可能是分数最高的前K个。这样一个自顶向下的方法可以非常高效的采集K个不同样本而不需要枚举所有句子。
图6 自顶向下的采样方法
图6展示了一个K=2的自顶向下的采样例子。我们先对<START>的对数概率进行扰动,得到-1.2,然后我们对所有候选序列<START> hello,<START> !和<START> world的对数概率进行扰动并进行纠正,得到-4.3,-3.2,-1.2,最后我们只保留对数扰动概率最高的<START> !和<START> world继续进行拓展,最终得到<START> world hello和<START> world world两个样本。
最新提出的基于Gumbel-Top-K的随机束搜索提供了一种高效的采样手段。利用这种方法,我们可以:
1. 对于需要采样来计算句子级损失的任务,可以更高效地训练模型;
2. 类似于使用Gumbel-Softmax的梯度作为Gumbel-Max梯度的有偏估计,为Gumbel-Top-K寻找类 似的梯度有偏估计,使得模型可以直接优化其搜索过程;
3. 概率化束搜索,为束搜索可能导致的一系列问题如过翻译,漏译等提供概率解释。
参考文献
Kool, W., Hoof, H.V., & Welling, M. (2019). Stochastic Beams and Where To Find Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement. ICML.
Shen, S., Cheng, Y., He, Z., He, W., Wu, H., Sun, M., & Liu, Y. (2015). Minimum Risk Training for Neural Machine Translation. ArXiv, abs/1512.02433.
作者介绍
李炎洋,东北大学自然语言处理实验室研究助理。东北大学自然语言处理实验室由姚天顺教授创建于 1980 年,现由朱靖波教授、肖桐博士领导,长期从事计算语言学的相关研究工作,主要包括机器翻译、语言分析、文本挖掘等。团队研发的支持119种语言互译的小牛翻译系统已经得到广泛应用。