雷锋网(公众号:雷锋网)按:本文为AI研习社编译的技术博客,原标题 A Comprehensive guide to Fine-tuning Deep Learning Models in Keras (Part I),作者为 Felix Yu 。
翻译 | 杨东旭 校对 | 孟凡 整理 | MY
这篇文章中,我们将对实践中的微调做一个全面的概述,微调是深度学习中常用的方法。
我将借鉴自己的经验,列出微调背后的基本原理,所涉及的技术,及最后也是最重要的,在本文第二部分中将分步详尽阐述如何在 Keras 中对卷积神经网络模型进行微调。
首先,为什么对模型进行微调?
当我们得到一个深度学习任务时,例如,一个涉及在图像数据集上训练卷积神经网络(Covnet)的任务,我们的第一直觉将是从头开始训练网络。然而,在实践中,像 Covnet 这样的深度神经网络具有大量的参数,通常在百万数量级。在一个小的数据集(小于参数数量)上训练一个 Covnet,会极大的影响网络的泛化能力,通常会导致过拟合。
因此,更常见的是微调一个在大数据集上已经训练好的模型,就像 ImageNet(120 万的标注图像),然后在我们的小数据集上继续训练(即运行反向传播)。假如我们的数据集与原始数据集(例如 ImageNet)在上下文中没有明显的不同,则预训练模型已经具有了处理我们自己的分类问题相应的学习特征。
何时微调模型?
一般来说,如果我们的数据集在上下文中与预训练模型的训练数据集没有明显不同,我们应该进行微调。像 ImageNet 这样大而多样的数据集上的预训练网络,在网络前几层可以捕获到像曲线和边缘这类通用特征,这些特征对于大多数分类问题都是相关且有用的。
当然,如果我们的数据集代表一些非常具体的领域,例如医学图像或中文手写字符,并且找不到这个领域的预训练网络,那么我们应该考虑从头开始训练网络。
另一个问题是,如果我们的数据集很小,那么在小数据集上微调预先训练的网络可能会导致过拟合,特别是如果网络的最后几层是全连接层,就像 VGG 网络的情况。根据我的经验,如果我们有几千个原始样本,并实现了常见的数据增强策略(翻译,旋转,翻转等),微调通常会使我们得到更好的结果。
如果我们的数据集非常小,比如少于一千个样本,则更好的方法是在全连接的层之前将中间层的输出作为特征(瓶颈特征)并在网络的顶部训练线性分类器(例如 SVM)。SVM 特别擅长在小型数据集上绘制决策边界。
微调技术
以下是一些实现微调通用的指导原则:
1. 常用的做法是截断预训练网络的最后一层(softmax 层),并将其替换为与我们自己的问题相关的新 softmax 层。例如,ImageNet 上经过预先训练的网络带有 1000 个类别的 softmax 层。
如果我们的任务是 10 个类别的分类,则网络的新 softmax 层将是 10 个类别而不是 1000 个类别。然后,我们在网络上运行反向传播来微调预训练的权重。确保执行交叉验证,以便网络具有很好的泛化能力。
2. 使用较小的学习率去训练网络。因为我们期望预先训练的权重相比随机初始化权重要好很多,所以不希望过快和过多地扭曲这些权重。通常的做法是使此刻的初始学习率比从头训练的初始学习率小 10 倍。
3. 还有一个常用的做法是冻结预训练网络的前几层的权重。这是因为前几层捕获了与我们的新问题相关的曲线和边缘等通用特征。我们希望保持这些权重的完整。相反,我们将在后面的层中专注于学习数据集中的特殊特征。
在哪里找到预训练网络?
这要取决于深度学习框架。对于像 Caffe,Keras,TensorFlow,Torch,MxNet 等流行的框架,他们各自的贡献者通常会保留已实现的最先进 Covnet 模型(VGG,Inception,ResNet 等)的列表和在 ImageNet 或 CIFAR 等常见数据集上的预训练权重。
找到这些预训练模型的最好方法是用 google 搜索特定的模型和框架。但是,为了方便您的搜索过程,我将在流行框架上的常用预训练 Covnet 模型放在一个列表中。
-
Caffe
-
Model Zoo -为第三方贡献者分享预训练 caffe 模型的平台
-
Keras
-
Keras Application – 实现最先进的 Convnet 模型,如 VGG16 / 19,googleNetNet,Inception V3 和 ResNet
-
TensorFlow
-
VGG16
-
Inception V3
-
ResNet
-
Torch
-
LoadCaffe – 维护一个流行模型的列表,如 AlexNet 和 VGG。从 Caffe 移植的权重
-
MxNet
-
MxNet Model Gallery – 维护预训练的 Inception-BN(V2)和 Inception V3。
在 Keras 中微调
……
想要继续阅读,请移步至我们的AI研习社社区:https://club.leiphone.com/page/TextTranslation/719
更多精彩内容尽在 AI 研习社。
不同领域包括计算机视觉,语音语义,区块链,自动驾驶,数据挖掘,智能控制,编程语言等每日更新。
雷锋网雷锋网
雷锋网雷锋网
。
原创文章,作者:ItWorker,如若转载,请注明出处:https://blog.ytso.com/industrynews/134347.html