在NLP的应用中,分类算法是最常用的算法,而分类算法最常用的损失函数是交叉熵。为什么我们会用交叉熵作为分类算法的标配呢?在模型的训练过程中,最小化交叉熵意味着模型学到了什么?为什么不用均方误差(MSE)作为分类算法的损失函数呢?如果MSE不好,那其他的损失函数,比如合页损失(Hing Loss)呢?

交叉熵最小话到底在学什么?

博文中介绍了信息熵、KL散度、交叉熵,从信息论的角度解释了为什么可以用交叉熵来作为分类算法的损失函数。
在机器学习中,实际上有三个概率分布:真实数据的分布、训练数据的分布、模型学习到的分布。
KL散度和交叉熵的关系:KL(A||B) = H(A, B) - S(A)
上式中,A、B表示两个分布,AB的KL散度就等于A与B分布的交叉熵减去A分布的熵。当A固定时,KL散度和交叉熵是等价的。

在模型的训练过程中,使模型的分布不断拟合训练数据的分布。如何评估这两个分布之间的差距呢?就可以用KL散度:

KL(training||model) = H(training, model) - S(training)

由于训练数据是固定的了,所以训练数据的熵是定值,即S(training)是固定的,此时KL散度等价与交叉熵,所以交叉熵可以衡量training分布与model分布的不同。

值得注意的是,如果model学习到的概率分布与训练数据的分布完全相同时,就意味着过拟合,所以要一个高斯分布的误差的存在,代表模型的泛化能力。

为什么不用MSE作为分类算法的损失函数呢?

博文中指出,如果使用MSE时,在接近0或1的位置,偏导数非常接近0,导致学习的会非常慢。

Logo

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。

更多推荐