作者 | Marin Vlastelica
编译 | 蒋宝尚
(雷锋网(公众号:雷锋网)作品)目前,在计算机这个学科中有两个非常重要方向:一个是离散优化的经典算法-图算法,例如SAT求解器、整数规划求解器;另一个是近几年崛起的深度学习,它使得数据驱动的特征提取以及端到端体系结构的灵活设计成为可能。
那么能否将组合器与深度学习相结合?
ICLR 2020 spotlight 论文《Differentiation of Blackbox Combinatorial Solvers》探讨了这一问题。
论文下载地址:https://arxiv.org/abs/1912.02175
在论文中,作者试着将组合求解器无缝融入深度神经网络,并在魔兽争霸最短路径问题、最小损失完美匹配问题以及旅行商问题中进行了测试。测试结果显示,其组合求解器+深度学习的方法达到的效果比传统的方法要好。
另外,论文的一作Marin Vlastelica,在Medium上撰文介绍了这篇论文的主要思想,雷锋网 AI科技评论作了有删改的编译,以下是原文请欣赏~
机器学习的研究现状表明,基于深度学习的现代方法与传统的人工智能方法确实存在不一致的地方。深度学习在计算机视觉、强化学习、自然语言处理等领域的特征提取方面有着强大的功能。虽然如此,但其在组合泛化问题(combinatorial generalization)上一直是研究者所诟病。
例如,将地图作为输入从而在 Google Maps 上预测最快路线的最短路径的规划问题;(Min,Max)-Cut 问题、最小损失完美匹配问题(Min-Cost Perfect Matching)、旅行商问题、图匹配问题等等。如果单独解决上述每一个问题,我们有很多工具可以选择:你可以用C语言,可以使用更通用的 MIP(mixed integer programming)求解器。当然求解器需要考虑输入空间问题,毕竟它需要定义良好的结构化输入。虽然组合问题已经成为机器学习研究领域的关注点,但对此类问题的研究力度尚且不足。
这也不是说研究者不重视组合泛化问题,毕竟它仍然是智能系统的关键挑战之一。理想情况下,研究者能够以端对端方式,通过强大的函数逼近器(如神经网络)将丰富的特征提取与高效的组合求解器结合起来。这也正是论文《Differentiation of Blackbox Combinatorial Solvers》中所实现的,另外,这篇论文获得了很高的评审分数,并入选为 ICLR 2020 spotlight 论文。文章接下来的部分,并不是在试图改进求解器,而是要将函数逼近和现有求解器协同使用。
假设黑盒求解器(blackbox solver)是一个可以轻松插入深度学习的结构模块。
黑盒求解器的梯度
将连续输入到离散输出之间的映射作为求解器的方式,另外,连续输入可以是图边的权重,离散输出可以是最短路径、选定的图边。其中,映射的定义如下
求解器可以将最小化一些损失函数c(ω,y),这些损失函数可以是路径的长度。用公式这种优化问题表示如下:
上式中,w为神经网络的输出,也就是神经网络学习的某种表示,例如可以是图边权重的某个向量。在最短路径问题、旅行商问题中,ω可以用来作出正确的问题描述。优化问题的关键是最小化损失函数,现在的问题是损失函数是分段表示的,也就是说存在跳跃间断点。这意味着对于表示 ω,该函数的梯度几乎处处为 0,并且在跳跃间断点处,梯度尚未被定义。目前,利用求解器松弛(solver relaxation)的方法能够解决这个问题,但会损失最优性。论文中提出了一种不影响求解器最优性的方法。即对原始目标函数的分段处用仿射插值来定义,另外插值由超参数 λ 控制,如下图所示:
如上所示,函数图像的黑色部分是原函数给出的值,橙色部分是利用插值法给出的值。最小值没有变化。
当然,f的域是多维的。因此,对于同一个f的取值,可以有多个w相对应。也就是说输入的ω的集合是一个多面体,输出的f可以是相同的值。自然地,在 f 的域中有许多这样的多面体。超参数 λ 有效地通过扰动求解器输入 ω 来使多面体偏移。定义了分段仿射目标的插值器 g 将多面体的偏移边界与原始边界相连。
如下图所示,取值 f(y2) 的多面体边界偏移至了取值 f(y1) 处。这也直观地解释了为什么更倾向使用较大的超参数λ。偏移量必须足够大才能获得提供有用梯度的插值器g
首先,定义一个扰动优化问题的解决方案,其中扰动由超参数λ控制,公式如下:
如果假设损失函数c(ω,y)是y和ω之间的点积,则可以定义插值目标:
损失函数的线性度并不像乍一看那样有限制性。例如,在边选择问题中,损失函数要考虑所有边权重的和,具体事例参考旅行商问题和最短路径问题。
雷锋网注:如上图所示,插值随着超参数λ的变化而变化
算法
使用该方法,可以通过修改反向传播来计算梯度,从而消除经典组合求解器和深度学习之间的不一致性。
def forward(ctx, w_):
"""
ctx: Context for backward pass
w_: Estimated problem weights
"""
y_ = solver(w_)
# Save context for backward pass
ctx.w_ = w_
ctx.y_ = y_ return y_
在前向传播中,只需给嵌入求解器提供 ω,然后将解向前传递。此外,我们保存了 ω 和在前向传播中计算得到的解 y_。
def backward(ctx, grad):
"""
ctx: Context from forward pass
"""
w = ctx.w_ + lmda*grad # Calculate perturbed weights
y_lmda = solver(w)
return -(ctx.y_ – y_lmda)/lmda
在后向传递中,用超参数λ的反向传播梯度来扰动 ω,并取先前解与扰动问题解之间的差值
计算插值梯度的计算开销取决于求解器,额外的开销有两次,一次是在前向传播过程中调用的一次求解器,另一次是在后向传播过程中调用的一次求解器。
实验
为了验证该方法,设计了具有一定程度复杂度的合成任务进行验证。
另外,因为简单的监督学习方法无法泛化至没有见过的数据,所以在下面的任务中,已经证明了此方法对于组合泛化的必要性。
对于最短路径问题,测试任务为魔兽争霸,训练集包括《魔兽争霸 II》地图,任务目标为地图对应的最短路径问题。具体而言,测试集包含了未知的《魔兽争霸 II》地图。地图本身编码为K*K网格。卷积神经网络的输入是地图,输出地图是顶点的损失,然后将该损失作为求解器的输入。最后,求解器(Dijkstra 最短路径算法)以指示矩阵的形式在地图上输出最短路径。
在训练的开始,神经网络不知道如何为地图的图块分配正确的损失,但是使用组合求解器+深度学习能够得到正确的成本,从而找到正确的最短路径。下列直方图表明,相比于 ResNet 的传统监督训练方法,此方法的组合泛化能力更棒。
在最小损失完美匹配问题上,使用的数据集是MNIST,任务目标是输出 MNIST 数字组成网格的最小损失完美匹配。具体而言,在此问题上,选择的边应该让所有的顶点都能够恰好被包含一次,另外还能够让损失之和最小。另外,网格中的每个单元都包含一个 MNIST 数字,该数字是图中具备垂直和水平方向邻近点的一个节点。最后,边的损失由垂直向下或水平向右的两位数字决定。
求解器输出匹配中所选边的指示向量。右侧的匹配损失为 348(水平为 46 + 12,垂直为 27 + 45 + 40 + 67 + 78 + 33)。
在下面这张性能图上,我们可以清晰看到在神经网络中嵌入真实的完美匹配求解器能够达到更好的效果。
在旅行商问题中,训练数据集是国旗(即原始表示)和对应首都的最优旅行线路。神经网络的输出是各个国家首都的最佳旅行线路。神经网络在训练的过程,最重要的学习首都位置的隐表示。包含K个国家的训练示例如下图所示。
将各个国家的国旗输入卷积神经网络,然后网络输出最优旅行线路。
在下面的动画中,也可以看到神经网络训练期间各国首都在全球范围内的位置。
起初,位置是随机分布的,但经过训练后,神经网络不仅学习输出正确的TSP旅行线路,而且学习输出正确的表示,即各个首都的正确3D坐标。值得注意的是,这仅仅是通过在监督训练过程中使用 Hamming 距离损失,以及对网络输出使用 Gurobi 中的 MIP 实现的。
总结
实际上,已经证明在求解器损失函数的某些假设下,可以通过黑盒组合求解器传播梯度。这能够让传统有监督方法的标准神经网络架构实现的组合泛化能力。
深度学习+组合求解器的学习方法能够在一些需要组合推理的现实问题上得到广泛的应用。然而问题在于求解器损失的线性这一假设前提上,在此假设下我们究竟可以走多远?未来工作的重点以及问题在于我们能否学习到组合问题的潜在约束,例如 MIP 组合问题。
参考文献
Vlastelica, Marin, et al. “Differentiation of Blackbox Combinatorial Solvers” arXiv preprint arXiv:1912.02175 (2019). (http://bit.ly/35IowfE)
Rolínek, Michal, et al. “Optimizing Rank-based Metrics with Blackbox Differentiation.” arXiv preprint arXiv:1912.03500 (2019). (http://bit.ly/35EXIMN)
https://towardsdatascience.com/the-fusion-of-deep-learning-and-combinatorics-4d0112a74fa7
。
原创文章,作者:ItWorker,如若转载,请注明出处:https://blog.ytso.com/137666.html