跳转至

Towards Causal Foundation Model: on Duality between Causal Inference and Attention

讲者: Chao Ma
讨论人: Jiaqi Zhang
来源: OCIS (Online Causal Inference Seminar)
日期: 2024-04-09
主题: 因果推断
视频: https://youtu.be/cnz17Q6g6Lw

本页据讲座录音的自动转写(ASR)生成。人名 / 术语 / 公式 / 具体的率与界可能被听错,关键处请对照视频或讲者论文核对。


一、这场报告在讲哪条工作线

这条工作线试图回答一个根本问题:能否构建一个“因果基础模型”(Causal Foundation Model, CFM),即一个预先训练好的、能够在零样本下对任意新的表格型观测数据集进行定量因果效应估计的单一模型。 这与用大规模语言模型(LLM)进行“定性/语义”因果推理不同,其目标是产出可量化的、可操作的因果效应值(如平均处理效应ATE),并具备泛化到未见因果机制的能力。

  • 子方向背景:因果推断与机器学习交叉领域长期以来面临两个方向的困难。一是“单数据集因果估计”方法(如逆概率加权IPW、Doubly Robust、Causal Forest、DragonNet等)在给定一个数据集时往往表现出色,但每个新数据集都需要从头重训练或重新拟合估计量【37:00-37:15】,在计算上无法支持高频/在线决策;二是大型基础模型(如LLM)在因果任务上表现不佳,缺乏定量能力,且其因果性能没有随模型规模扩展而显著提升(没有“因果缩放定律”)【13:00-13:32】。这条工作线试图同时克服上述两个缺陷:一方面利用深度学习架构的泛化潜力,另一方面严格限制其训练目标以保障因果可识别性。

  • 这条路线的奠基与主流

  • 经典协变量平衡:报告明确将自身工作定位为对“最优平衡”(Optimal Balancing)这一经典思想的应用【19:30-20:15】。该方向由一系列工作奠基,例如Hainmueller (2012)的熵平衡,Imai & Ratkovic (2014)的协变量平衡倾向得分(CBPS),以及更近期的基于核方法的平衡权重(如Kallus, 2020的凸/最优平衡)。
  • 因果基础模型先驱:报告引用了近期对LLM在因果推理任务上的基准测试(Kiciman et al., 2023/2024? 转写音似,需查证),并指出其在定量预测上的不足。报告本身的工作是作为“因果基础模型”的第一个具体技术步骤——即将因果估计与注意力机制建立等价的数学对偶,从而允许使用Transformer架构来“执行”因果推断。
  • 最新进展:报告提及了团队内的后续工作——“FIB: A Fixed-Point Approach for Causal Foundation Models”(正在ICML 2024评审中)【40:00-40:13】。FIB将本工作的对偶思想从潜在结果框架扩展到结构因果模型(SCM),能处理反事实预测和因果发现,并据报告介绍,它提供了“目前最通用的反事实分布可识别性结果”【40:50-41:05】。

  • 这场报告的站位:它并不宣称已经实现了一个规模化、多任务的因果基础模型。它展示的是实现此目标的第一块“工程积木”:在一个非常具体且简化的场景下(ATE估计,二值处理,潜在结果与特征由再生核希尔伯特空间RKHS建模),严谨证明了一个Transformer注意力头在特定损失(hinge + regularization)下的最优解,恰好等价于最优协变量平衡权重的解。这个等价性(对偶性)是整个CFM项目的研究基础。

二、最小内核 / 一个最简例子

核心想法:报告发现,寻找最优的无偏处理效应估计量这个问题,在特定条件下,完全等价于训练一个线性注意力头去预测某个特定的目标(例如,处理变量W本身)。这个注意力头学到的“值”(value)恰好就是最优的平衡权重。

符号与模型

  1. 可观测数据\(D = \{ (X_i, T_i, Y_i) \}_{i=1}^n\),其中:

    • \(X_i \in \mathbb{R}^d\):协变量。
    • \(T_i \in \{0, 1\}\):二值处理。
    • \(Y_i \in \mathbb{R}\):观测到的结果。
  2. 目标估计量(Estimand):平均处理效应 (ATE) \(\tau = \mathbb{E}[Y(1) - Y(0)]\),其中\(Y(1), Y(0)\)是潜在结果。

  3. 潜在不可观测量:每个单元的潜在结果\(Y(1)\)\(Y(0)\)之一观测不到。

  4. 可识别假设:无混杂(Unconfoundedness / Ignorability)\(\{Y(1), Y(0)\} \perp T | X\),以及重叠性\(0 < P(T=1|X) < 1\)

  5. 估计量形式:报告使用一种基于协变量平衡的ATE估计量: \(\hat{\tau}(\alpha) = \sum_{i: T_i=1} \alpha_i Y_i - \sum_{i: T_i=0} \alpha_i Y_i\). 其中\(\alpha_i\)是每个单元的权重,且\(\sum_{i: T_i=1} \alpha_i = \sum_{i: T_i=0} \alpha_i = 1\). 这个估计量的核心任务是找到一组最优权重\(\alpha^*\),使得加权后的协变量分布(即\(\sum_i \alpha_i X_i\) 等)在处理组和对照组之间达到平衡,从而近似一个随机化实验。

  6. 函数类假设:报告假设未知的潜在结果函数\(\mu_0(x) = \mathbb{E}[Y(0) | X=x]\)属于某个再生核希尔伯特空间 (RKHS) \(\mathcal{H}\),并带有特征映射\(\phi(x)\)和核函数\(K\)【29:20-29:30】。

最简例子(\(d=1\), \(T\in\{0,1\}\), \(n=3\)

  • 数据:三个样本点: \((X, T, Y) = (0.2, 0, 5), (0.8, 0, 8), (0.5, 1, 10)\)
  • 问题:由于处理组\(Y=10\)\(X=0.5\)与对照组\(X\)分布不同,直接用均值差\(10 - (5+8)/2 = -1.5\)不是ATE的一个无偏估计。我们需要找到权重\(\alpha_1, \alpha_2\)(对应对照组)和\(\alpha_3\)(对应处理组)使得加权协变量平衡(例如\(\alpha_1 * 0.2 + \alpha_2 * 0.8 \approx \alpha_3 * 0.5\)),且估计量\(\hat{\tau} = \alpha_3 * 10 - (\alpha_1 * 5 + \alpha_2 * 8)\)

  • 报告的核心声明:寻找最优\(\alpha\)的问题,等价于训练一个注意力模型。具体地:

    • 我们构建一个(单层)注意力头,它把整个数据集\(D\)当作“上下文”,把数据集中的每一个样本当作一个“token”。
    • 这个注意力头被训练去做一个分类任务:根据其他所有样本的信息,来预测某个目标样本的处理值\(T_i\)(或者报告中的\(W_i\),即\(-1/1编码的\)T_i$)。
    • 训练损失是 Hinge损失(类似于SVM)加上一个正则化项,鼓励权重之和为1等约束。
    • 报告的数学结果:在这个上下文注意力头中,每个token对应的“值”(Value)向量,在训练达到全局最优时,会恰好等于最优权重\(\alpha_i\)乘以某个与\(Y_i\)\(X_i\)相关的常数【31:35-31:50】。换句话说,注意力机制内部存储的“值”就是最优平衡权重。
  • 对偶性的直观理解:这个问题可以看作是一个双层优化: 外层(原始问题,因果推断):寻找最优权重\(\alpha\),以最小化最坏情况(在所有可能的RKHS函数\(\mu\)下)的ATE估计误差。 内层(对偶问题,注意力):对于给定的权重\(\alpha\),寻找最佳预测模型(注意力头)来预测\(W\),预测误差由Hinge损失衡量。 报告证明了这两个问题是彼此的“对偶”,且在最优解下,外层的最优\(\alpha^*\)就是内层注意力模型产生的Value向量。

三、报告主体:讲者讲了什么

0:00 [00:00 - 01:30] 开场与介绍。主持人介绍讲者Chao Ma(微软研究院剑桥)和讨论者Jiaqi Zhang(MIT博士生,曾是Ma的实习生)。

0:15 [01:30 - 06:00] 问题背景与动机。 - 核心论点:当前的基础模型(如LLM)在因果推理上有根本性缺陷。它们是“关联学习”,不是“因果学习”。 - 挑战1:LLM只能给出定性的因果回答(如“打折会增加收入”),而因果模型需要输出定量的、可操作的结果(如“打折10%使收入增加$20,000”)。 - 挑战2:LLM无法处理未知/全新的干预(如推出前所未有的销售策略),因为它们依赖于训练数据中的案例。因果模型应能通过潜在结果模型或结构因果模型来泛化。 - 挑战3:缺乏因果缩放定律。模型越大(从随机基线到GPT-4),其因果推理能力并没有显著提升。这表明“扩大规模”不是解决因果问题的方法。

0:15 [06:00 - 18:00] 提案:因果基础模型。 - 报告提出了一个愿景:训练一个全新的基础模型,直接在所有可用的定量数据集上训练,目标是做出正确的因果预测。该模型应该是零样本的,即对于任何一个新的观测数据集,它只需一次前向传播,就能在不需要重新训练的情况下给出因果效应估计。 - 思想转变:不把因果知识编码在语言里,而是编码在预测任务模型架构中。模型学习的是“如何从观测数据中恢复因果效应”的通用模式。

0:18 [18:00 - 24:10] 关键技术:对偶性 (Duality) 的高层介绍。 - 核心结果:报告的核心发现是,学习一个基于注意力的模型(即优化注意力头的参数),在数学上等价于解决一个“最优因果预测”问题。这两个问题是彼此的“对偶”。 - 两个基础知识: 1. 最优平衡:一种比IPW更通用的协变量平衡方法。目标是找到一组权重\(\alpha\),让观察数据的分布与随机实验的分布对齐,从而在加权后可以简单地计算ATE。 2. 注意力机制:报告将Transformer中的注意力机制描述为一个“软性选择/加权”过程。通常用于语言模型,根据Query和Key的相似度,对Value进行加权求和。

0:24 [24:10 - 34:00] 核心方法:CAusal INference with Attention (CINA)。 - 更精确的数学表述:讲者展示了如何在数学上将对偶性具体化。 1. 损失函数:训练一个注意力模型,其损失由两部分组成:一个基于Hinge Loss的“分类/预测项”(预测处理变量\(W\))和一个“正则项”【27:00-27:20】。 2. 证明路线[27:50-34:00]: - 假设潜在结果函数属于RKHS,定义ATE估计风险的最坏情况最小化问题(假设我们不知道真实的潜在结果,但在所有满足RKHS范数约束的函数中最小化误差)【29:00-29:30】。 - 这个minimax问题可以转化为一个核SVM形式的优化问题【30:00-30:20】。在这个问题中,对偶变量恰好是平衡权重\(\alpha\)。 - 核SVM问题可以进一步转化为另一个等效的SVM原问题,这个原问题的形式正是:优化一个注意力模型的参数,其中Hinge Loss是分类损失,\(V\) (value vector) 是由\(\alpha\)和输入\(X\)计算得出的【31:30-32:00】。 3. 关键公式:Value tensor \(V_i\) 在最优解时,正比于 \(\alpha_i \cdot (2W_i - 1) \cdot \phi(X_i)\) 等形式,\(K\)是一个特征映射的积分或中间层。这意味着,一旦模型训练好,我们可以直接从训练好的注意力头的Value中“读出”最优平衡权重。 4. 方法命名:CINA(Causal Inference with Attention)。它是一个使用Transformer架构的模型,在大量合成数据集上预训练。当输入新的观测数据时,进行一次前向传播,模型输出一个权重向量\(\hat{\alpha}\),用这个权重直接加权计算ATE,无需任何后续模型拟合。

0:34 [34:00 - 39:00] 实验验证。 - 实验设置:生成大量(高达5000+)的合成因果任务。每个任务由一个随机生成的因果图、一个非线性结构化因果模型(SCM)定义。在不同的图上生成观测数据。预训练模型(CINA)被要求对所有训练任务中的数据点进行“去混杂”(预测权重)。然后在未见过的1000个新合成任务上进行零样本评估。实验还包括从合成数据预训练直接迁移到真实世界ACIC(Atlantic Causal Inference Conference)数据集上的零样本测试。 - 主要结论: - 零样本性能优越:在合成数据和真实数据上,CINA的ATE估计误差(RMSE等)显著优于所有传统方法(包括DragonNet、TARNet等)。需要注意的是,所有传统方法都是在测试数据上单独重新训练或重新拟合的,而CINA是零样本【37:00-37:30】。 - 泛化能力强:CINA在从合成数据到真实数据(ACIC)的零样本泛化中取得了最好成绩,即使传统方法是在ACIC数据上专门调优的【38:00-38:20】。 - 推理速度快:CINA的零样本推理速度(单次前向传播)远远快于需要基于新数据重新训练的传统方法【38:25-38:55】。 - 有监督 vs 无监督:报告澄清,训练可以不依赖ATE的ground truth(纯无监督/自监督,符号ZS),但如果有部分实验数据提供ground truth,性能可以进一步提升(有监督版本ZSS)【44:00-45:10】。

0:39 [39:00 - 42:15] 讨论与扩展。 - 讨论者 (Jiaqi Zhang) 的主要观点: 1. CINA如何工作:将表格数据视为序列,用注意力让每个单元(token)“关注”所有其他单元,从其他单元的上下文信息中学习自己的平衡权重。这利用了“自注意力即SVM”的观察和“最优平衡即核SVM”的观察。 2. 关于假设:讨论了未观测混杂(unobserved confounders)的可行性。即使有未观测混杂\(Z\),如果观测到的协变量\(X\)\(T\)能提供对\(Z\)的充分信息(近似oracle),最优平衡方法依然有效,这与因果表示学习(causal representation learning)的近期趋势相似。 3. 未来方向: - 数据扩展:除了用新方法处理表格数据,是否也可以如Geneformer那样,将表格数据转化为类似语言的表示,再利用现成的LLM架构? - 与人类偏好对齐(RLHF):在因果任务中,人类反馈如何用于模型更新(例如,让专家评估ATE估计的可信度)? - 数据来源:规模化所需的数据从哪来?是更好的合成模拟?还是利用非公开的真实世界数据集?还是利用像单细胞数据这样的新领域数据? - 讲者回应:对于假设,应该视其为一种归纳偏差。在模型的原始形式中,若假设完全满足且有监督信号的辅助,模型能更好学习。随着数据量增加,模型可能会自动发现超越原始假设的因果模式。

0:45 [42:15 - 45:00] Q&A与收尾。 - 问答:回答了关于“动态因果”(时间序列/面板数据)【1:02:06-1:02:40】、“任务特异性”(是否不同暴露变量需要不同模型)【1:03:08-1:04:25】等问题。讲者承认当前模型聚焦于ATE估计,是任务特化的,但正在通过FIB等项目进行扩展。 - 结尾:感谢听众与合作者。

四、对应论文与开放问题

  • 对应论文(基于转写推断,字幕可能有误):

    • 核心工作:Chao Ma, et al. "Towards Causal Foundation Model: on Duality between Causal Inference and Attention". 此文应与MSR Causality Team的预印本或非正式发布版本对应。具体arXiv ID查证。讲者提及论文下载链接在幻灯片上,但转写中未提及具体ID。
    • 相关扩展"FIB: A Fixed-Point Approach for Causal Foundation Models"。讲者称其在ICML 2024审查中【40:10-40:13】。此文将CINA的对偶思想从潜在结果框架扩展到结构因果模型(SCM),支持反事实和因果发现。会议室成员Jiaqi Zhang是此文的共同作者。
  • 开放问题(扎根于转写中的特定时间点):

    1. 假设的鲁棒性:报告的核心证明假设了无混杂和RKHS。在真实世界中,未观测混杂普遍存在。讨论者Jiaqi Zhang【52:00-53:00】指出CINA的训练数据合成过程包含了不同程度的未观测混杂,但模型如何应对未知的未观测混杂仍是一个开放问题。转写依据:[52:00-53:00] "there was also this previous work, by Banana and Claus... which says if you have a latent confounder... you're fine using optimal balancing as long as you have an approximate oracle..." 这表明假设放松是已知研究前沿。
    2. 任务特异性:听众提问“因果基础模型是否会太任务特定(task-specific)?是否每个暴露变量都需要不同模型?”【1:03:08-1:04:25】。讲者承认当前模型(CINA)只能估计ATE,是任务特定的。未来的扩展方向是处理个体处理效应(CATE)、反事实和更一般的非结构化任务。转写依据:[1:03:50-1:04:15] "currently only addresses one type of task... it only solves ATE estimation".
    3. 非表格数据:讨论者Jiaqi Zhang指出CINA当前仅适用于表格型数据。对图像、视频等非结构化数据,目前的方法是将它们转换为数值特征或嵌入向量,但尚未验证。如何将因果与注意力对偶性扩展到更复杂的多模态数据是一个开放问题。转写依据:[54:20-55:00] "so far the experiments that we've seen in this paper are basically all tabular data sets".
    4. 零样本与inductive bias:报告提出的模型在训练时(在合成数据上)不需要ATE ground truth。但听众提问【1:03:08-1:04:25】暗示,零样本泛化能力可能依赖于训练数据中因果机制的分布。如何在没有ground truth的情况下,确保模型学到的是真正的因果结构,而不是表面上的虚假相关,是一个核心挑战。转写依据:[42:30-45:00] 关于合成数据与无监督训练的矛盾讨论(Jiaqi Zhang提问,讲者Chao Ma回应)。

Maintained by 陈星宇 · Homepage · Source on GitHub

评论