博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
SWATS算法剖析(自动切换adam与sgd)
阅读量:5288 次
发布时间:2019-06-14

本文共 1266 字,大约阅读时间需要 4 分钟。

SWATS算法剖析(自动切换adam与sgd)

27 人赞同了该文章

SWATS是ICLR在2018的高分论文,提出的一种自动由Adam切换为SGD而实现更好的泛化性能的方法。

论文名为Improving Generalization Performance by Switching from Adam to SGD,下载地址为:。

作者指出,基于历史梯度平方的滑动平均值的如adam等算法并不能收敛到最优解,因此在泛化误差上可能要比SGD等方法差,因此提出了一种转换机制,试图让算法自动在经过一定轮次的adam学习后,转而由SGD去执行接下来的操作。

算法本身思想很简单,就是采用adam这种无需操心learning rate的方法,在开始阶段进行梯度下降,但是在学习到一定阶段后,由SGD接管。这里前面的部分与常规的adam实现区别不大,重要的是在切换到sgd后,这个更新的learning rate如何计算。 整个算法步骤流程如下:

 

 

熟悉adam的应该能熟悉蓝色的部分,这个就是adam的原生实现过程。

作者比较trick的地方就是14行到24行这一部分。这一部分作者做了部分推导,[公式]作为最后的切换learning rate。

算法的整个实现逻辑并不复杂,这里列出自己实现时遇到的一些问题。

填坑 & 问题

  1. 在上面的算法流程第12行,有个[公式],这个在整个流程中未介绍如何实现,本人阅读论文后,发现应该是学习率衰减的设计。一如很多深度学习策略一样,这里可以设置经过若干轮迭代后,学习率降为原来的1/N。在论文中,作者使用了在150轮后,将学习速率减少10倍。即[公式]
  2. 上面说了[公式]的更新,我们通过公式推导,其实能发现[公式][公式]有一定的关系,自己代码实现的版本,发现切换的时机很大程度上和[公式]有关,因为切换涉及到第17行的一个比较过程,[公式][公式]本身都与[公式]相关,当[公式]降一个量级时,[公式]|本身也会更接近[公式]。其有些类似正比关系,因此一般都是在经过一定轮次的衰减后,才能触发SGD切换时机。这一点目前本人实现验证是这样,未深入推理。
  3. 这个[公式]还有个坑,就是实现该算法,开始不太清楚这个k到底指的是epoch,还是指的经历的batch数量。最后按照常规学习率衰减应该是按照epoch来算的,因此推测其k应该为epoch。
  4. 还有和大坑是[公式]作为学习率,在切换到SGD后应一直不变,该值为标量,因此应该如常用eta等学习率一样,为正值,因此需要在17行加个约束,即[公式]。(该场景难以复现,之前有次更新发现不设置为正值时,导致切换sgd后准确度大减)

总结

通过若干的对比,该论文变相增加了一些超参数,所以实际使用有待商榷。自己的数据集上经常就在还未满足切换条件就已经收敛了。 目前已做了相应的实现,放在scalaML中,位置为,使用见。最后想要查看切换过程的话,建议将early_stop设置为false,然后将学习率衰减系数设置低一点。 代码目前仅支持二分类。

转载于:https://www.cnblogs.com/think90/p/11515242.html

你可能感兴趣的文章
MFC网络编程TCP/IP的服务器与客户端代码
查看>>
线程池的用法Android
查看>>
Java学习路线-知乎
查看>>
python-study-06
查看>>
IDEA配置maven中央库
查看>>
C# 基础
查看>>
mybatis进阶-5resultMap总结
查看>>
【OOAD】OOP的主要特征
查看>>
MapReduce进行本地的单词统计
查看>>
HTTP 状态码
查看>>
【转】详解硬盘MBR
查看>>
bashrc 文件命令
查看>>
hdu4271 Find Black Hand 2012长春网络赛E题 最短编辑距离
查看>>
V7000初始化
查看>>
animate.css的使用
查看>>
Struts2 注释类型
查看>>
JSP中EL表达式语言不能使用的解决方法
查看>>
做XH2.54杜邦线材料-导线
查看>>
如何刻录cd音乐
查看>>
Codeforces Round #318(Div 1) 573A, 573B,573C
查看>>