《从离散域DA到图域DA》
来给大家介绍一下我们ICLR 2022上的新工作,这个也是子昊同学 @shsjxzh 第一个ML的工作(撒花)。个人觉得,这个工作有意思的地方在于,它是第一个把传统domain adaptation(DA)的范式从离散域推广到图域,然后同时拿到了理论的保障和性能的提高。先放下论文的链接:“Graph-Relational Domain Adaptation(GRDA)”,http://wanghao.in/paper/ICLR22_GRDA.pdf
传统DA v.s. 图DA:借用下我们讲CIDA的知乎帖子的图,传统的DA,一般都是从一个(或几个)domain,adapt到另一个(或几个)domain,如下图:
而我们这个工作思考的问题是:如果domain之间存在一个描述domain之间关系的域图(domain graph),我们能不能把domain adaptation做得更好,以及有没有理论保障。比如下图所示,我们有15个domain,以及他们之间的domain graph。如果我们希望从右边6个source domain,adapt到左边的9个target domain,应该怎么做才好呢?
在解决这个问题之前,我们先来思考下,传统的DA方法,有什么缺陷。
传统的方法:处理这类问题,大家比较喜欢用Adversarial Domain Adaptation的方式。基本思路就是,首先把不同domain的数据,通过一个编码器(encoder)来提取feature,然后通过对抗训练(adversarial training)的手段,来迫使源域(source domain)和目标域(target domain)的feature分布完全对齐(重合)。为啥要让他们重合呢?因为重合后,源域和目标域的feature就可以使用一个预测器(predictor)来做预测。
传统方法的问题:那么这些传统的adversarial DA方法有什么问题呢?问题就在于,它们把各个domain当成相互独立的,从而无视了domain之间的关系。这样的话,它们在学encoder的时候,就会盲目地把所有不同domain的feature强制完全对齐。这样做是有问题的,因为有的domain之间其实联系并不大,强行对齐它们反而会降低预测任务的性能。
比如下面的例子,如果我们要为纽约(NY)训练一个预测模型(比如天气预报),把宾州(PA)的模型adapt到纽约(NY)应该可以达到不错的效果,因为这两个州地理位置相近;但是如果我们把加州(CA)的模型adapt到纽约(NY),可能效果就会很差。
引入domain graph:读到这里,可能大家已经意识到了,在这个问题里,我们刚好可以利用刚刚提出的domain graph来描述各个州(每个州是一个domain)之间的关系。比如下图就是一个domain graph,它描述了美国东北边15个州(15个domain)之间的近邻关系(adjacency)。
方法:在我们定义了domain graph之后,方法其实就非常简单自然了,如下图。我们只需要对传统的adversarial DA方法做一下简单的改动:(1)传统的方法直接把data x作为encoder的输入,而我们把domain index u以及domain graph作为encoder的输入,(2)传统的方法让discriminator对domain index进行分类,而我们让discriminator直接重构(reconstruct)出domain graph。
我们把这个新的方法叫做graph-relational domain adaptation(GRDA),然后把GRDA里面这个新提出的discriminator叫做graph discriminator。下面的一个图更加细致地比较了下传统的discriminator和我们的graph discriminator的区别。
理论:方法很简单,那么有没有什么有趣而实用的理论性质呢?因为我们用的是adversarial training,本质上是在求一个minimax game的均衡点(equilibrium)。在传统的DA方法上,因为discriminator做的是分类,我们可以很自然地证明,这个minimax game的均衡点就是会完全对齐所有domain。那么有意思的问题来了:我们用了graph discriminator后,这个minimax game的平衡点会有什么性质呢?换句话说,我们训练这个GRDA,最后会收敛到什么样子呢?GRDA也会完全对齐所有domain吗?
神奇的是,我们可以证明(如下图),在任何domain graph的情况下,当GRDA训练到最优时是可以保证不同domain的feature会根据domain graph来对齐的(Theorem 1),而不是强行让所有domain完全对齐。而且这个GRDA训练,不会影响到predictor的预测准确度(Theorem 2)。
更有意思的是,这里我们还能证明,传统的DA方法,其实是我们GRDA的一个特例。这个特例其实非常直观:传统的DA方法(完全对齐所有domain)会等价于当GRDA的domain graph是全连接图(fully-connected graph or clique)时的情况(Corollary 1)。
那么问题来了,如果GRDA的domain graph不是全连接图时,会发生什么事呢?我们在论文中也证明了,当GRDA的domain graph不是全连接图时,传统DA方法的解(即“完全对齐所有domain”)依然会是GRDA的最优解(均衡点)之一,但是,GRDA会有很多其他的最优解(均衡点)。也就是说,GRDA其实是对传统DA方法的relaxation。
一些有趣的domain graph特例:我们的上面的Theorem 1给出了GRDA在任意domain graph下的最优解。接下来,我们重点看一些有趣的domain graph特例。
(1)全连接图:正如上面所说,domain graph是全连接图(如上面左图)时,GRDA会等价于传统的DA。
(2)星形图:当domain graph是星形(如上面中图)时,如果GRDA达到最优解,那么中间domain(紫色)encoding分布p(e)会是外围所有domain的分布的平均,这个其实也非常直观。
(3)链状图:当domain graph是链状(如上面右图)时,如果GRDA达到最优解,那么只有相邻的domain的encoding分布p(e)之间会有直接的影响。具体地讲,我们可以证明:当且仅当下面这个式子对于任意e和e’都成立时,GRDA达到最优解。注意下面的p_i(e)表示的是domain i的encoding分布,而没有下标的p(e)表示的则是所有domain的p_i(e)的平均。
实验结果之Toy Dataset:作为评测的第一步,我们构造了一个15个domain的toy dataset及其对应的domain graph(如下图的左边),我们把它叫做DG-15。从下面图的右边,我们可以看到,GRDA的accuracy可以大幅超过其他的方法,特别是其他方法在离source domain比较远(从domain graph的角度)的target domain的准确率并不是很高,但是GRDA却能够保持较高的准确率。
我们还构造了一个类似的、有60个domain的toy dataset,叫做DG-60,也做了类似的实验。整体的准确度比较如下。
实验结果之气温预测:沿着我们上文说的天气预报的例子,我们构建了一个美国大陆48个州的气温数据集,叫做TPT-48,并在这个数据集上评测了各种方法。如下图(注意domain graph),为了全面评测,我们构建了2类任务:
(1)E->W:以东部的州为source domain(下图左边的黑色圆圈),以西部的州为target domain(下图左边的白色圆圈)。
(2)N->S:以北方的州为source domain(下图右边的黑色圆圈),以南方的州为target domain(下图右边的白色圆圈)。
大家可能发现了,我们还把target domain按照距离source domain的距离分为3层(Level-1,Level-2,和Level-3 Target Domain),这是为了更全面地评测各种方法在各个层次target domain的准确率。结果如下图,注意因为气温预测是一个回归(regression)任务,所以我们用的metric是MSE(越低越好)。我们可以大致看出来,对于所有的方法,距离source domain越远的target domain,预测得越不准,这个是在意料之中的。而我们的GRDA基本可以稳定达到最好的平均效果。
写在最后:熟悉的同学可能可以看出来,这个GRDA其实是跟我们ICML’20的”Continuously Indexed Domain Adaptation”(CIDA,http://wanghao.in/paper/ICML20_CIDA.pdf)是一脉相承的。
希望大家看了之后能够有所启发,没有启发的话,不是子昊同学这个工作做的不好,而是我这个帖子写得不好,所以也请轻拍:)
Paper: https://arxiv.org/abs/2202.03628 or http://wanghao.in/paper/ICLR22_GRDA.pdf
YouTube Video: https://www.youtube.com/watch?v=oNM5hZGVv34
Bilibili Video: https://www.bilibili.com/video/BV14L4y177Br?spm_id_from=333.999.0.0
ICLR Link: https://iclr.cc/Conferences/2022/Schedule?showEvent=7145
GitHub Link (still re-organizing the code): https://github.com/Wang-ML-Lab/GRDA
来源:知乎 www.zhihu.com
作者:王灏
【知乎日报】千万用户的选择,做朋友圈里的新鲜事分享大牛。
点击下载
此问题还有 39 个回答,查看全部。
延伸阅读:
ICLR 2021 有什么值得关注的投稿?
本文转自: http://www.zhihu.com/question/490962362/answer/2443450582?utm_campaign=rss&utm_medium=rss&utm_source=rss&utm_content=title
本站仅做收录,版权归原作者所有。