在科学研究中,从方法论上来讲,都应“先见森林,再见树木”。当前,人工智能学术研究方兴未艾,技术迅猛发展,可谓万木争荣,日新月异。对于AI从业者来说,在广袤的知识森林中,系统梳理脉络,才能更好地把握趋势。为此,我们精选国内外优秀的综述文章,开辟“综述专栏”,敬请关注。
作者: DengBoCong地址:https://www.zhihu.com/people/dengbocong
本文打算讨论在深度学习中常用的十余种损失函数(含变种),结合PyTorch和TensorFlow2对其概念、公式及用途进行阐述,希望能达到看过的伙伴对各种损失函数有个大致的了解以及使用。本文对原理只是浅尝辄止,不进行深挖,感兴趣的伙伴可以针对每个部分深入翻阅资料。
使用版本:
TensorFlow2.3
PyTorch1.7.0
01
交叉熵损失(CrossEntropyLoss)
对于单事件的信息量而言,当事件发生的概率越大时,信息量越小,需要明确的是,信息量是对于单个事件来说的,实际事件存在很多种可能,所以这个时候熵就派上用场了,熵是表示随机变量不确定的度量,是对所有可能发生的事件产生的信息量的期望。交叉熵用来描述两个分布之间的差距,交叉熵越小,假设分布离真实分布越近,模型越好。
在分类问题模型中(不一定是二分类),如逻辑回归、神经网络等,在这些模型的最后通常会经过一个sigmoid函数(softmax函数),输出一个概率值(一组概率值),这个概率值反映了预测为正类的可能性(一组概率值反应了所有分类的可能性)。而对于预测的概率分布和真实的概率分布之间,使用交叉熵来计算他们之间的差距,换句不严谨的话来说,交叉熵损失函数的输入,是softmax或者sigmoid函数的输出。交叉熵损失可以从理论公式推导出几个结论(优点),具体公式推导不在这里详细讲解,如下:
预测的值跟目标值越远时,参数调整就越快,收敛就越快;
不会陷入局部最优解
交叉熵损失函数的标准形式(也就是二分类交叉熵损失)如下:
多分类交叉熵损失如下:
Tensorflow:
BinaryCrossentropy[1]:二分类,经常搭配Sigmoid使用
binary_crossentropy[2]
CategoricalCrossentropy[3]:多分类,经常搭配Softmax使用
categorical_crossentropy[4]
SparseCategoricalCrossentropy[5]:多分类,经常搭配Softmax使用,和CategoricalCrossentropy不同之处在于,CategoricalCrossentropy是one-hot编码,而SparseCategoricalCrossentropy使用一个位置整数表示类别
sparse_categorical_crossentropy[6]
PyTorch:
BCELoss[7]
BCEWithLogitsLoss[8]:其实和TensorFlow是的`from_logits`参数很像,在BCELoss的基础上合并了Sigmoid
CrossEntropyLoss[9]
我们在计算预测和真实标签之间损失时,需要拉近他们分布之间的差距,即模型得到的预测分布应该与数据的实际分布情况尽可能相近。KL散度(相对熵)是用来衡量两个概率分布之间的差异。模型需要得到最大似然估计,乘以负Log以后就相当于求最小值,此时等价于求最小化KL散度(相对熵)。所以得到KL散度就得到了最大似然。又因为KL散度中包含两个部分,第一部分是交叉熵,第二部分是信息熵,即KL=交叉熵−信息熵。信息熵是消除不确定性所需信息量的度量,简单来说就是真实的概率分布,而这部分是固定的,所以优化KL散度就是近似于优化交叉熵。下面是KL散度的公式:
联系上面的交叉熵,我们可以将公式简化为(KL散度 = 交叉熵 - 熵):
监督学习中,因为训练集中每个样本的标签是已知的,此时标签和预测的标签之间的KL散度等价于交叉熵。
TensorFlow:
KLD | kullback_leibler_divergence[10]
KLDivergence[11]
Pytorch:
KLDivLoss[12]
03
平均绝对误差(L1范数损失)
缺点:
梯度恒定,不论预测值是否接近真实值,这很容易导致发散,或者错过极值点。
导数不连续,导致求解困难。这也是L1损失函数不广泛使用的主要原因。
优点:
收敛速度比L2损失函数要快,这是通过对比函数图像得出来的,L1能提供更大且稳定的梯度。
对异常的离群点有更好的鲁棒性,下面会以例子证实。
TensorFlow:
MAE | mean_absolute_error[13]
MeanAbsoluteError[14]
MeanAbsolutePercentageError[15]:平均绝对百分比误差
MAPE | mean_absolute_percentage_error[16]:平均绝对百分比误差
Huber[17]
PyTorch:
L1Loss[18]
l1_loss[19]
SmoothL1Loss[20]:平滑版L1损失,也被称为 Huber 损失函数。
smooth_l1_loss[21]
04
均方误差损失(L2范数损失)
缺点:
收敛速度比L1慢,因为梯度会随着预测值接近真实值而不断减小。
对异常数据比L1敏感,这是平方项引起的,异常数据会引起很大的损失。
优点:
它使训练更容易,因为它的梯度随着预测值接近真实值而不断减小,那么它不会轻易错过极值点,但也容易陷入局部最优。
它的导数具有封闭解,优化和编程非常容易,所以很多回归任务都是用MSE作为损失函数。
TensorFlow:
MeanSquaredError[22]
MSE | mean_squared_error[23]
MeanSquaredLogarithmicError[24]
MSLE | mean_squared_logarithmic_error[25]
PyTorch:
MSELoss[26]
mse_loss[27]
05
Hinge loss
扩展到多分类问题上就需要多加一个边界值,然后叠加起来。公式如下:
Tensorflow:
CategoricalHinge[28]
categorical_hinge[29]
Hinge[30]
hinge[31]
SquaredHinge[32]
squared_hinge[33]
PyTorch:
06
余弦相似度
余弦相似度是机器学习中的一个重要概念,在Mahout等MLlib中有几种常用的相似度计算方法,如欧氏相似度,皮尔逊相似度,余弦相似度,Tanimoto相似度等。其中,余弦相似度是其中重要的一种。余弦相似度用向量空间中两个向量夹角的余弦值作为衡量两个个体间差异的大小。相比距离度量,余弦相似度更加注重两个向量在方向上的差异,而非距离或长度上。
余弦相似度更多的是从方向上区分差异,而对绝对的数值不敏感,更多的用于使用用户对内容评分来区分用户兴趣的相似度和差异,同时修正了用户间可能存在的度量标准不统一的问题(因为余弦相似度对绝对数值不敏感),公式如下:
Tensorflow:
CosineSimilarity[35]:请注意,所得值是介于-1和0之间的负数,其中0表示正交性,而接近-1的值表示更大的相似性。如果y_true或y_pred是零向量,则余弦相似度将为0,而与预测值和目标值之间的接近程度无关。
cosine_similarity[36]
PyTorch:
07
总结
上面这些损失函数是我们在日常中经常使用到的,我将TensorFlow和PyTorch相关的API都贴出来了,也方便查看,可以作为一个手册文章,需要的时候点出来看一下。还有一些其他的损失函数,后续也会都加进来。
[1] https://www.tensorflow.org/api_docs/python/tf/keras/losses/BinaryCrossentropy
[2] https://www.tensorflow.org/api_docs/python/tf/keras/losses/binary_crossentropy
[3] https://www.tensorflow.org/api_docs/python/tf/keras/losses/CategoricalCrossentropy
[4] https://www.tensorflow.org/api_docs/python/tf/keras/losses/categorical_crossentropy
[5] https://www.tensorflow.org/api_docs/python/tf/keras/losses/SparseCategoricalCrossentropy
[6] https://www.tensorflow.org/api_docs/python/tf/keras/losses/sparse_categorical_crossentropy
[7] https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
[8] https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
[9] https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
[10] https://www.tensorflow.org/api_docs/python/tf/keras/losses/KLD
[11] https://www.tensorflow.org/api_docs/python/tf/keras/losses/KLDivergence
[12] https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
[13] https://www.tensorflow.org/api_docs/python/tf/keras/losses/MAE
[14] https://www.tensorflow.org/api_docs/python/tf/keras/losses/MeanAbsoluteError
[15] https://www.tensorflow.org/api_docs/python/tf/keras/losses/MeanAbsolutePercentageError
[16] https://www.tensorflow.org/api_docs/python/tf/keras/losses/MAPE
[17] https://www.tensorflow.org/api_docs/python/tf/keras/losses/Huber
[18] https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html
[19] https://pytorch.org/docs/stable/nn.functional.html?highlight=loss#torch.nn.functional.l1_loss
[20] https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html
[21] https://pytorch.org/docs/stable/nn.functional.html?highlight=loss#torch.nn.functional.smooth_l1_loss
[22] https://www.tensorflow.org/api_docs/python/tf/keras/losses/MeanSquaredError
[23] https://www.tensorflow.org/api_docs/python/tf/keras/losses/MSE
[24] https://www.tensorflow.org/api_docs/python/tf/keras/losses/MeanSquaredLogarithmicError
[25] https://www.tensorflow.org/api_docs/python/tf/keras/losses/MSLE
[26] https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html
[27] https://pytorch.org/docs/stable/nn.functional.html?highlight=loss#torch.nn.functional.mse_loss
[28] https://www.tensorflow.org/api_docs/python/tf/keras/losses/CategoricalHinge
[29] https://www.tensorflow.org/api_docs/python/tf/keras/losses/categorical_hinge
[30] https://www.tensorflow.org/api_docs/python/tf/keras/losses/Hinge
[31] https://www.tensorflow.org/api_docs/python/tf/keras/losses/hinge
[32] https://www.tensorflow.org/api_docs/python/tf/keras/losses/SquaredHinge
[33] https://www.tensorflow.org/api_docs/python/tf/keras/losses/squared_hinge
[34] https://pytorch.org/docs/stable/generated/torch.nn.HingeEmbeddingLoss.html
[35] https://www.tensorflow.org/api_docs/python/tf/keras/losses/CosineSimilarity
[36] https://www.tensorflow.org/api_docs/python/tf/keras/losses/cosine_similarity
[37] https://pytorch.org/docs/stable/generated/torch.nn.CosineEmbeddingLoss.html