在自然语言处理seq2seq模型中,模型训练完成后,预测推理时需要预测每一步输出的最可能的单词,之后组合成完整的预测输出句子。这里每一步最可能的输出单词的选择就用到greedy search或者beam search。下面详细介绍一下这两种搜索的区别,以及实现方法。
greedy search比较简单,就是贪婪式的搜索,每一步都选择概率最大的单词输出,最后组成整个句子输出。这种方法给出的结果一般情况结果比较差,因为只考虑了每一步的最优解,往往里全局最优解差距很大。
贪婪搜索实现比较简单,这里就不写了,每一部找一个最大值就好了。
beam search是介于全局搜索和贪婪搜索之间。这里先讲一下全局搜索,全局搜索考虑的是全局最优解,需要把每一种可能输出结果都算出来,然后找出概率最大的输出。这种搜索空间是非常巨大的,假设我们的词表大小为N,句子长度为T个单词,整个搜索时间复杂度为O(N*N*T),一般N取值在几万到几十万级别,T在为几百个单词,实际计算比较慢。
而降低时间复杂度的方法就是寻找次优解,具体就是把搜索空间中的N减下来,每一步计算完只保留K个最大的取值路径,这样时间复杂度降为O(K*N*T),K取值一般比N小很多。这样得到的虽然不是最优解,但是在seq2seq模型的推理预测中可以兼顾时间和效果。
下面用python简单实现一下beam-search算法,这里实现的是假设每一步输出的可能概率是提前算好的,然后传入一个k值,计算beam-search的最优输出。
实际seq2seq中beam-search是要考虑每一部状态输出的,下一步预测输出需要上一步的状态,这里的实现暂时未考虑。
写法1:每次先排序,然后选择排序后的k个最大值。
写法2:利用快排的思想,找出第k大的值,根据partition划分直接找到k个最大的值。 numpy中argpartition可以完成这个操作。
写法3:利用堆排序,建立一个k个元素的小顶堆,每次通过与堆顶比较判断大小,并更新堆,最后完成topk最大的目标。 heapq中有现成的方法nlargest。