雷锋网 AI 科技评论按:这篇博客来自 Jetpac(现被谷歌收购) CTO、苹果毕业生、TensorFlow 团队成员 Pete Warden。文中讨论了一个很容易被机器学习领域的研究人员们忽略的问题:你是否真的清楚数据对模型表现有多大影响,同时你又有没有付出适当的精力在改善你的数据上呢?已经为生产环境开发过模型的研究人员相信已经对这件事足够重视,不过也不妨重温一下其中的重要思路。
原文标题「Why you need to improve your training data, and how to do it」,雷锋网 AI 科技评论全文编译如下。
Andrej Karpathy(char-rnn 作者,https://twitter.com/karpathy)在「Train AI」(https://www.figure-eight.com/train-ai/)会议上演讲的时候放了这张幻灯片,我非常喜欢它!他的幻灯片完美地抓住了深度学习研究和工业生产之间的差异(读博的时候大多数熬夜都是在调模型,但在企业里大多数熬夜是在处理数据)。学术论文的关注点几乎都在具有创新性、性能有所提升的模型上,而它们使用的数据集往往都是一小部分的公开数据集。然而,据我所知,在实际应用中使用深度学习技术的人往往会花大量的时间分析考虑他们的训练数据。
研究人员如此关注模型架构的理由有很多,但这同时也意味着,能帮助指导如何把机器学习用在生产环境中的资料非常少。为了解决这个问题,我在 Train AI 大会上的演讲题目为「训练数据无法解释的有效性」,并且在这篇博文中,我想对这个题目进行进一步的拓展,解释为什么数据如此重要,介绍一些改进数据的实用技巧。
作为我工作的一部分,我与许多研究人员和产品团队进行了密切的合作。当我看到他们专注于模型构建过程中的数据改进工作,并且由此取得了大量的成果,我对于数据改进的威力便深信不疑。在大多数应用中,使用深度学习技术的最大阻碍是在现实世界中获得足够高的准确率,而改进训练数据集是我所见到的最快的能够提升准确率的途径。即使你受限于延迟、存储空间等因素,在特定的模型上提升准确率也可以让你能够通过使用较小的模型架构在这些性能指标上进行折衷。
在这里,我不能公开分享我对于工业生产系统的观察,但是我为大家提供了一个开源的示例来演示相同的模式。去年,我创建了一个TensorFlow 环境下的「简单语音识别示例」(https://www.tensorflow.org/tutorials/audio_recognition)。结果表明,并没有现有的数据集可以被很容易地用于训练模型。在大量的志愿者的慷慨的帮助下,通过使用 AIY 团队协助我开发的「开放式语音记录网站」(https://aiyprojects.withgoogle.com/open_speech_recording),我收集到了 60,000 条一秒长的人类说短单词的音频片段。由此得到的模型是可用的,但是仍然没有达到我想要的准确度。为了看看这种情况在多大程度上是由于我自己作为一个模型设计者存在的不足,我使用相同的数据集举办了一个Kaggle 竞赛(https://www.kaggle.com/c/tensorflow-speech-recognition-challenge/)。参赛者的模型性能比我的简单模型要好得多,但尽管有许多不同的方法,多支队伍的准确率都止步于 91% 左右。这告诉我数据中存在着本质上错误的问题。而实际上参赛者们也发现了很多错误,比如不正确的标签或被截断的音频。这促使我开始动手发布一个修复了参赛者发现的问题的新数据集,并且提供更多的样本。
我查看了一下错误度量结果,从而了解该模型对于哪些词语存在的问题最多。结果显示,「其它」类别(语音能够被识别,但是相应单词在模型有限的词汇表中无法找到)尤其容易出错。为了解决这个问题,我加大了我们捕获的不同单词的数量,从而为训练数据提供更大的多样性。
因为 Kaggle 参赛者报告了标注错误的问题,我众包了一个额外的验证过程,要求人们听每一个音频片段,并确保它与预期的标签相匹配。由于参赛者在Kaggle 竞赛中还发现了一些几乎无声或者被截断了的音频文件,我还编写了一个实用程序来进行一些简单的音频分析(https://github.com/petewarden/extract_loudest_section),并且自动删除质量特别差的样本。最终,得益于更多的志愿者和一些有偿众包接包者的努力,尽管删除了质量不佳的文件,我将语音样本的数量增大到了超过 100,000 份的规模。
为了帮助他人使用该数据集(并且从我的错误中吸取教训!)我将所有有关的内容和更新后的准确率的结果写到了一篇论文中(https://arxiv.org/abs/1804.03209)中。最重要的结论是,在完全不改变模型结构或测试数据的情况下,首位准确率提高了超过 4 个百分点,从 85.4% 提高到了 89.7%。这是一个显著的提升,并且当人们在安卓或树莓派上尝试应用 demo 时,结果也令人满意得多。尽管我知道我现在使用的并非最先进的模型,但是我坚信如果我把时间都花在模型架构的调整上,我将无法取得如此大的提升。
在生产环境中,我看到这种处理方法一次又一次地产生了很好的的结果,但如果你想做同样的事情,却可能无从下手。你可以从我在语音数据上使用的技术中得到一些启发,但是接下来,我将向你介绍一些我认为非常实用的方法。
这似乎是显而易见的,但你首先应该做的便是随机地浏览你将要开始使用的训练数据。你需要将文件复制到本地机器上,然后花几个小时预览它们。如果你要处理图像,你可以使用 MacOS 的 finder 这样的工具来滚动浏览缩略视图,这样就能很快地查看数以千计的图片;对于音频文件来说,则可以使用 finder 播放预览;或者对于文本文件来说,将随机片段转存到你的终端中。我并没有花费足够的时间对第一版语音控制系统进行上述操作,而这也正是为什么一旦 Kaggle 竞赛的参赛者开始处理数据就发现了如此之多的问题。
我一直觉得这个处理过程有点傻,但是做完后我从未后悔过。每当我完成这些数据处理工作,我都会发现一些对数据至关重要的东西,无论是不同种类样本数量的失衡、损坏的数据(例如将 PNG 文件的扩展名标注为 JPG 文件)、错误的标签,或者仅仅是令人感到惊讶的数据组合。通过仔细的检查,Tom White 在 ImageNet 中取得了许多有趣的发现。例如:
标签「sunglass」(太阳镜),实际上指的是一种古代的用于放大太阳光线的装置(下图)
而 ImageNet 中的「sunglass」和「sunglasses」(太阳眼镜)高度混淆(左侧图像类别被标注为「sunglass」、右侧类别为「sunglasses」,摊手)
标签为「garbage truck」(垃圾车)的分类中,有一张美女的照片
标签为「cloak」(斗篷)的图像似乎对不死女有偏见
除此之外,Andrej 手动地对 ImageNet 中的图片进行分类的工作(http://karpathy.github.io/2014/09/02/what-i-learned-from-competing-against-a-convnet-on-imagenet/)也教会了我很多关于这个数据集的知识,包括如何分辨出所有不同的狗的种类,甚至是对一个人来说。
你将采取什么样的行动取决于你发现了什么,但是在做其它任何的数据清理工作之前,你总是需要做这种检查,因为对于数据集中内容直观的了解可以帮助你在接下来的步骤中做出正确的决策。
不要花太长的时间去选择恰当的模型。如果你要做图像分类,请查看 AutoML(https://cloud.google.com/automl/),否则可以看看 TensorFlow 的模型仓库(https://github.com/tensorflow/models/)或者 Fast.AI 收集的样例(http://www.fast.ai/),直接找到和你的产品解决的问题相似的模型。重要的是尽可能快的开始迭代,这样你就可以尽早并且经常地让真实用户来测试你的模型。你总是可以,而且可能看到更好的实验结果,但首先你得保证数据没有问题。深度学习仍然遵循「无用输入得到无用输出」(garbage in,garbage out)的基本计算法则,所以即使是最好的模型的性能也会受到你的训练数据中的缺陷的限制。通过选择一个模型并且对其进行测试,你就能够理解这些这些缺陷是什么,并开始改进它们。
为了进一步加快你的模型的迭代速度,你可以试着从一个已经在大规模的现有的数据集上预训练好的模型开始,通过迁移学习使用你收集到的(可能小的多的)数据集对它进行调优。这通常会比仅仅在你较小的数据集上进行训练得到的结果好得多,因此你可以快速地对如何调整你的数据收集策略有一个清晰的认识。最重要的是,你可以根据你可与你在你的数据收集过程中考虑结果中的反馈,从而在学习的过程中进行调整,而不是仅仅在训练之前将数据收集所谓一个单独的阶段运行。
在建立学术研究模型和工业生产模型之间最大的区别在于,学术研究通常在一开始就定义了一个明确的问题的声明,但实际用用的需求则存在于用户的脑海中,并且只能随着时间的推移慢慢变得清晰明了。例如,对于 Jetpac 来说,我们希望找到好的图片,将其展示在城市的自助旅行指南中。
这是一个极端的例子,但是它说明标注过程在很大程度上取决于应用的需求。对于大多数生产中的用例来说,找出模型需要回答的正确的问题需要花费很长一段时间,而这对于正确地解决问题是十分关键的。如果你正在用你的模型回答错误的问题,你将永远无法在这个糟糕的基础上建立一个良好的用户体验。
我发现唯一能够判断你所问的问题是否正确的方法就是模拟一个你的应用程序,而不是建立一个有人类参与决策的机器学习模型。由于有人类在背后参与决策,这种方法优势被称为「Wizard-of-Oz-ing」。而对于 Jetpac 来说,我们让使用者手动地从一些旅游指南样例中选择图片,而不是训练一个模型,然后从测试用户的反馈来调整我们用于选择图片的标准。一旦我们能够可靠地从训练中获得正反馈,我们就可以将我们设计的图片选择规则迁移到一个标注指南中,从而获得数以百万计的图片作为训练集。接着,这些数据被用来训练出能够预测数十亿图片的标签的模型,但它的 DNA(核心思想)来自我们设计的原始的手动选择图片的规则。
在上述的 Jetpac 的例子中,我们用于训练模型的图像和我们希望将模型应用到的图片的来源相同(大部分来自 Facebook 和 Instagram),但是我发现的一个常见问题是,训练数据集与模型使用的输入数据的一些关键性的差异最终会体现在生产结果中。
举例而言,我经常会看到一些团队试图在无人机或机器人上使用 ImageNet 数据训练出的模型时遇到问题。之所以会出现这种情况是因为 ImageNet 中的数据大多为人为拍摄的照片,而这些照片存在着很多共同的特性。这些图片是用手机或静态照相机拍摄的,它们使用中性镜头,拍摄的高度大约与头部平行,在日光或人造光线下拍摄,需要标记的物体位于图片中央并位于前景中。而机器人和无人机使用视频摄像机,通常采用拥有大视场角的镜头,其拍摄的位置不是在地面上就是在高空中,通常拍摄的光线都比较差,并且由于没有智能化地对图像进行定位的机制,通常只能对图片进行裁剪。这些差异意味着,如果你只是利用 ImageNet 中的图片训练模型并将其部署到上述的某台设备上,那么你将得到较低的准确率。
你所使用的训练数据和模型最终的输入数据还可能有一些细微的差异。不妨想象你正在建造一个能识别野生动物的相机,并且使用来自世界各地的动物数据集来训练它。如果你只打算将它部署在婆罗洲的丛林中,那么图片应该被标注为企鹅的概率会极其的低。而如果训练数据中包含南极的照片,那么模型将会有很有可能将其他动物误认为企鹅,模型整体的准确率会低于不使用这部分训练数据时的准确率。
有很多方法可以让你根据已知的先验知识(例如,在丛林中,大幅度降低图片被标注为企鹅的概率)来校准你的结果,但使用能够反映产品真实场景的训练集会更加有效。我发现最好的方法就是始终使用直接从实际应用程序中获取的数据,这与我上面提到的 Wizard of Oz 方法能很好地结合在一起。也就是说,在训练过程中使用人做出决策可以改为对你的初始数据集进行标注,即使收集到的标签数量非常少,它们也可以反映真实的使用情况,并且也应该能够满足进行迁移学习的初步的实验的基本要求。
当我研究语音控制系统的例子时,我最常看到的报告之一就是训练期间的混淆矩阵。下面是一个在控制台中显示的例子:
[[258 0 0 0 0 0 0 0 0 0 0 0]
[ 7 6 26 94 7 49 1 15 40 2 0 11]
[ 10 1 107 80 13 22 0 13 10 1 0 4]
[ 1 3 16 163 6 48 0 5 10 1 0 17]
[ 15 1 17 114 55 13 0 9 22 5 0 9]
[ 1 1 6 97 3 87 1 12 46 0 0 10]
[ 8 6 86 84 13 24 1 9 9 1 0 6]
[ 9 3 32 112 9 26 1 36 19 0 0 9]
[ 8 2 12 94 9 52 0 6 72 0 0 2]
[ 16 1 39 74 29 42 0 6 37 9 0 3]
[ 15 6 17 71 50 37 0 6 32 2 1 9]
[ 11 1 6 151 5 42 0 8 16 0 0 20]]
这可能看起来有点让人摸不着头脑,但实际上它只是一个显示网络错误分类的细节的表格。下面为大家展示一个更加美观的版本:
表中的每一行代表一组与实际标签相同的样本,每列显示预测出的标签结果的数量。例如,高亮显示的行表示所有实际上是无声的音频样本,如果你从左至右阅读则一行,可以看到标签预测的结果是完全正确的,因为每个预测标签都落在将样本预测为无声音频的列中。这告诉我们,该模型非常善于正确识别真正的无声音频样本,不存在误判。如果我们从一整列的角度来看这个表格,第一列显示有多少音频片段被预测为无声样本,我们可以看到一些实际上是单词的音频片段被误认为是无声的,显然这其中有大量的误判。这样的结果对我们来说非常有用,因为它让我更加仔细地观察那些被错误地归类为是无声样本的音频片段,而这些片段中又很多都是在相当安静的环境下录音的。这帮助我通过删除音量较低的音频片段来提高数据的质量,而如果没有混淆矩阵的线索,我将不知道该如何处理它。
几乎所有对预测结果的总结都可能是有用的,但是我认为混淆矩阵是一个很好的折衷方案,它提供的信息比仅仅给出准确率的数字更多,同时也不会包含太多细节,让我无法处理。在训练过程中,观察数字的变化也很有用,因为它可以告诉你模型正在努力学习什么类别,并可以让你知道在清理和扩展数据集时需要注意哪些领域。
我最喜欢的一种理解我的网络如何解读训练数据的方式是可视化聚类。TensorBoard 为这种探索方法提供了很好的支持,虽然它经常被用于查看词嵌入,但我发现它几乎适用于任何像嵌入技术一样工作的的网络层。例如,图像分类网络在最后的全连接或softmax 单元之前通常具有可以被用做嵌入表示的倒数第二层(这也正是像「TensorFlow for Poets」(https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0)这样的简单的迁移学习示例的工作原理)。这些并不是严格意义上的嵌入,因为我们并没有在训练过程中设法确保在真正的嵌入中具有你所希望得到的理想的空间属性,但对它们的向量进行聚类确实会得到一些有趣的结果。
举一个实际的例子,之前一个和我一同工作的团队在他们的图像分类模型中困惑于对某些动物的高错误率。他们使用可视化聚类技术来观察他们的训练数据是如何分布到不同类别中去的,当他们看到「美洲豹」时,他们清晰地看到发现数据被分成两个彼此之间存在一定距离的不同的组
上图即为它们所看到的聚类结果。一旦将每个聚类的图片展示出来,我们可以很明显地看到有很多「捷豹」牌汽车被错误地标注为美洲豹。一旦他们知道了这一问题,他们就能够检查标注过程,并且意识到工作人员的指导和用户界面是令人困惑的。有了这些信息,他们就能够改进(人类)标注者的训练过程并修复工具中存在的问题,将所有汽车图像从美洲豹的类别中删除,使得模型在该类别上取得更高的分类准确率。
通过让你深入了解你的训练集中的内容,聚类提供了与仅仅观察数据相同的好处,但网络实际上是通过根据自己的学习理解将输入分组来指导您的探索。作为人类,我们非常善于从视觉上发现异常情况,所以将我们的直觉和计算机处理大量输入的能力相结合为追踪数据集质量问题提供了一个高可扩展的解决方案。关于如何使用 TensorBoard 来完成这样的工作的完整教程超出了本文的范围,不多赘述。但如果你真的想要提高模型性能,我强烈建议你熟悉这个工具。
持续收集数据
我从来没有见过收集更多的数据而不能提高模型准确性的情况,结果表明,有很多研究都支持我的这一经验。
该图来自「谷歌证明数据为王,初创公司们被泼上了一盆冰水」,展示了即使训练集的规模增长到包含数以亿计的样本,图像分类的模型准确率也在不断提高。Facebook 最近进行了更深入的探索,它们使用数十亿带标签的 Instagram 图像在 ImageNet 图像分类任务上获得了新的准确率最高的记录(「发美照时打上 #,还能帮Facebook提升图片识别率哟」)。这表明,即使对拥有大规模、高质量数据集的任务来说,增加训练集的大小仍然可以提高模型的性能。
这意味着,只要有任何用户可以从更高的模型准确率中受益,你就需要一个可以持续改进数据集的策略。如果可以的话,你可以寻找具有创造性的方法来利用甚至是十分微弱的信号来获取更大的数据集。Facebook 使用 Instagram 标签就是一个很好的例子。另一种方法是提高标注过程的智能化程度,例如通过将模型的初始版本的标签预测结果提供给标注人员,以便他们可以做出更快的决策。这种方法的风险是可能在早期引入一些偏差,但实际上我们所获得的好处往往超过这种风险。聘请更多的人标注新的训练数据来解决这个问题,通常也是一项划算的投资行为,但是对于那些对这类支出没有预算的传统的组织来说,这可能十分困难。如果你运营的是非营利性组织,让你的支持者通过某种公共工具更方便地自愿提供数据,这可能是在不增加开支的情况下加大数据集规模的好方法。
当然,对于任何组织来说,最优的解决方案都是应该有一种产品,它可以在使用时自然地生成更多的带标签数据。尽管我不会过于关注这个想法,因为它在很多真实的用例中都不适用,毕竟人们只是想尽快得到答案,而不希望参与到复杂的标注过程中来。如果你经营一家初创公司,这是一个很好的投资项目,因为它就像是一个改进模型的永动机,但在清理或增强你所拥有的数据时难免会涉及到一些单位成本。所以经济学家最终经常会选择一种比真正免费的方案看起来更加便宜一点的商业众包版本。
模型错误对应用程序用户造成的影响几乎总是大于损失函数可以捕获的影响。你应该提前考虑可能的最坏的结果,并尝试设计一个模型的底线以避免它们发生。这可能只是一个因为误报的成本太高而不想让模型去预测的类别的黑名单,或者你可能有一套简单的算法规则,以确保所采取的行动不会超过某些已经设定好的边界参数。例如,你可能会维护一个你不希望文本生成器输出的脏话词表,即便它们确实存在于训练集中。因为它们出现在你的产品中是很不恰当的。
究竟会得到怎样不好的结果在事先并不总是那么明显,所以从现实世界中的错误中吸取教训是至关重要的。要做的好一点,最简单的方法之一就是在一旦你有一个半成品的时候,就使用错误报告。当人们使用你的应用程序时,如果她们得到了它们不想要的结果,他们就可以很容易地告诉你。如果可能的话,你需要获得完整的模型输入,但当它们是敏感数据时,仅仅知道什么是不好的输出是什么同样有助于指导你的调查。这些类别可被用于选择收集更多数据的来源,以及您应该去了解其哪些类别的当前标签质量。一旦对模型进行了新的调整,除了正常的测试集之外,还应该对之前产生不良结果的输入进行单独的测试。考虑到单个指标永远无法完全捕捉到人们关心的所有内容,这个错例图片库类似于回归测试,并且为您提供了一种可以用来跟踪你改善用户体验的方法。通过观察过去引发强烈反应的一小部分例子,你已经得到了一些独立的证据来表明你实际上正在为你的用户提供更好的服务。如果因为数据过于敏感而无法获取模型的输入数据,可以使用内部测试或内部实验来确定哪些输入会产生这些错误,然后替换回归数据集中的那些输入。
我希望我已经设法说服你在数据的准备和处理上花费更多时间,并且向你提供了一些关于如何改进它的想法。目前这个领域还没有得到应有的重视,我甚至觉得我在这篇文章中所讲的也都只是皮毛,所以我感谢每一个与我分享他们的研究策略的人,我希望未来我能听到更多的关于你们已经成功应用的方法的消息。我认为会有越来越多的组织建立工程师团队致力于数据集的改进,而不是仅仅让机器学习研究人员来推动这个领域的研究。我期待看到整个领域能够得益于此而取得进展。我常常惊讶于模型即使在具有严重缺陷的数据集上也能很好地工作,所以我迫不及待地想看看,随着我们对于数据的改进,我们能够做些什么!
via petewarden.com,雷锋网 AI 科技评论编译