跳转至

The Cram Method for Efficient Simultaneous Learning and Evaluation

讲者: Kosuke Imai
讨论人: Rui Song and Hengrui Cai - Q&A moderator: Michael Li
来源: OCIS (Online Causal Inference Seminar)
日期: 2024-04-02
主题: 因果推断
视频: https://youtu.be/DSAcBPsoPa4 · 幻灯片

本页据讲座录音的自动转写(ASR)生成。人名 / 术语 / 公式 / 具体的率与界可能被听错,关键处请对照视频或讲者论文核对。


一、这场报告在讲哪条工作线

这场报告的核心问题是:如何用同一份数据,同时学习一个决策/预测规则(policy),并给出这个具体学习规则的统计性能评估(point estimate + confidence interval)? 这个子方向通常被称为“simultaneous learning and evaluation”或“post-selection inference”中的一类——不是为算法选参数(如CV),而是为最终投入使用的那个规则做推断。

奠基与主流路线: - Sample splitting(数据分割):将数据分成训练集和测试集。优点是实现简单、评估无偏;缺点是数据利用不充分(训练集比例大则评估方差大,训练集比例小则学习质量差)。这是实践中无处不在的 baseline。 - Cross-validation(交叉验证):通过重复划分来更有效地利用数据。但它评估的是 ML 算法的平均性能(average performance of the algorithm),而不是你最终在完整数据上训练出的那一个具体规则。这一区别意味着:若直接把 CV 的均值和方差当作最终规则的性能估计,会系统性地低估不确定性(因为你忽略了最终规则本身也是由完整数据训练的随机变量)。此外CV需要多次重拟合算法,计算昂贵。 - 共作者之前的工作:讲者 Kosuke Imai 和他的合作者 Michael Lingzhi Li 之前就这个问题已有系列研究,但讲者未列出具体论文标题(幻灯片未提,转写中也未给)。

这场报告的贡献(Cram Method): 提出一个名为 Cram(“填鸭式”) 的通用框架。它同时克服了 sample splitting 和 CV 的缺陷: 1. 数据高效:用全部数据来学习最终规则,也用全部数据来评估这个规则(不像 spliting 要“割一块”做 evaluation)。 2. 计算高效:学习和评估通过数据的一次性顺序遍历完成(单次 pass),不重复拟合;支持 online learning 算法。 3. 评估的是学习规则,而非算法平均。

该方向当前的 frontier: - 如何扩展到非 i.i.d. 设置(bandit、reinforcement learning、时间序列)。讨论人 Rui Song 的工作(SAVE, Dream)以及讲者自己提到的 bandit / active learning 扩展都指向这个方向。 - 稳定性条件(stability condition)是该框架的核心假设,如何放松或自动满足它,是目前理论上的主要开放点。

这场报告站在“从 sample splitting → 一个更高效、同时也能做统计推断的替代方案”这个节点上。它不是设置全新的反事实识别问题,而是在已知可识别条件下,用巧妙的顺序划分结构提高学习与推断的效率


二、最小内核 / 一个最简例子

设置(与幻灯片一致): - 可观测数据:i.i.d. 样本 \( D_n = \{(X_i, D_i, Y_i)\}_{i=1}^n \)。 - \(X \in \mathcal{X} \subset \mathbb{R}^p\):预处理协变量(观测)。 - \(D \in \{0,1\}\):二值处理(观测)。 - \(Y = Y(D)\):观测到的结果(观测)。 - 潜在不可观测量:potential outcomes \(Y(1), Y(0)\)(个体层面不可同时观测)。 - 识别假设(强可忽略性)\( \{Y(1), Y(0)\} \perp\!\!\!\perp D \mid X \)(unconfoundedness)+ \( 0 < e(x) = P(D=1 \mid X=x) < 1 \)(overlap)。 - 待估目标(estimand): - Policy \(\pi(x) = P(D=1 \mid X=x) \in [0,1]\)(可为 deterministic 或 stochastic)。 - Policy value\( V(\pi) = \mathbb{E}[Y(1)\pi(X) + Y(0)(1-\pi(X))] \)。 - Policy value difference\( \Delta(\pi; \pi') = V(\pi) - V(\pi') \)。 - 学习目标:\( \hat\pi_T = \arg\max_{\pi \in \Pi} \hat V(\pi) \)(用全部数据学习)。 - 评价目标:报告 \( V(\hat\pi_T) \)(或与 baseline \(\pi_0\) 的差值)及其 uncertainty。

最简例子(d=1, 两个时间点): - 把数据随机分成 T=3 个大小相等的 batch:\(B_1, B_2, B_3\)。 - 初始 baseline policy \(\pi_0\)(例如“不治疗任何人”)。 - Iteration 1:用 \(B_1\) 通过学习算法 A 得到 \(\hat\pi_1\)。使用剩余批次 \(B_2 \cup B_3\) 评估 \(\Delta(\hat\pi_1; \pi_0)\),得到 \(\hat\Delta(\hat\pi_1; \pi_0)\)。 - Iteration 2:用 \(B_1 \cup B_2\) 学习 \(\hat\pi_2\)。使用剩余批次 \(B_3\) 评估 \(\Delta(\hat\pi_2; \hat\pi_1)\),得到 \(\hat\Delta(\hat\pi_2; \hat\pi_1)\)。 - 最终的 Cram estimator 为:

\[\hat\Delta(\hat\pi_3; \pi_0) = \hat\Delta(\hat\pi_1; \pi_0) + \hat\Delta(\hat\pi_2; \hat\pi_1)\]
注意我们从不评估最后一个迭代的更新(\(\Delta(\hat\pi_3; \hat\pi_2)\)),因为已经没有剩余数据了。我们把这一项当作“missing term”,并依靠稳定性条件保证它很小。 - 每个 \(\hat\Delta(\hat\pi_t; \hat\pi_{t-1})\) 采用 IPW(inverse probability weighting)估计:
\[\hat\Delta(\hat\pi_t; \hat\pi_{t-1}) = \frac{1}{|B_{t+1} \cup \dots \cup B_T|} \sum_{j=t+1}^T \sum_{i \in B_j} \left[\frac{Y_i D_i}{e(X_i)} - \frac{Y_i (1-D_i)}{1-e(X_i)}\right] (\hat\pi_t(X_i) - \hat\pi_{t-1}(X_i))\]
这是一个对该 batch 中观测到的个体 i,用 IPW 估计个体处理效应,再乘以政策差异,然后平均。

核心思想: - 数据被用于反复训练→测试→训练→测试,而不是一次性地训练一次、测试一次。 - 早期迭代的政策变化大,但恰好有很多剩余数据用于评估这些变化。后期政策几乎稳定(变化很小),此时虽然评估数据很少,但因为被评估的变化本身很小,误差依然可控。 - 最后通过加总所有已评估的变化,得到从 initial policy 到 final policy 的总差值,等价于我们想要的 \(\Delta(\hat\pi_T; \pi_0)\)


三、报告主体:讲者讲了什么

3.1 引言与动机 [0:00-0:06] - 讲者指出实践中普遍使用数据驱动算法做决策/预测,但关键问题是用同一份数据同时学习规则并评估它。 - 传统方案:Sample splitting(80%训练、20%测试)数据利用率低;Cross-validation 数据高效但不评估学习规则(只评估算法平均性能),且计算昂贵。 - 这引出 Cram 方法的动机:要同时做到数据高效、计算高效、评估学习规则本身

3.2 Cram 方法概览 [0:06-0:14] - 讲者用 [0:09:00] 处幻灯片里的“Cramming at Glance”图演示过程:将数据切为 T 个 batch,按顺序迭代训练和测试。 - 关键分解 [0:12:00]:\( \Delta(\hat\pi_T; \pi_0) = \sum_{t=1}^T \Delta(\hat\pi_t; \hat\pi_{t-1}) \),其中 \(\hat\pi_0 = \pi_0\)。实际估计只累加到 T-1,略去最后一项(missing term)。 - [0:13:30] 讲者强调“最终政策是完整数据学到的”,因此评估的 uncertainty 包含了学习过程本身的变异,这正是与 CV 的本质区别。

3.3 算法与估计量 [0:14-0:20] - 给出算法伪代码(幻灯片7): - 随机分 batch; - 迭代:\(\hat\pi_t = \mathcal{A}(\cup_{j=1}^t B_j)\)(学习),\(\hat\Delta(\hat\pi_t; \hat\pi_{t-1})\) 用剩余 batch 评估; - 最后累加。 - 每个 \(\hat\Delta\) 使用简单 IPW 估计量(幻灯片8),但讲者说也可以换成 doubly-robust 或其他无偏估计。 - [0:18:30] 提出另一种“列视角”表示:每个 batch j 被用于评估所有先前迭代的政策差(乘上适当的权重)。这利用了 condition on previous batches 后的独立结构,方便分析。

3.4 稳定性条件 [0:20-0:24] - [0:21:00] 讲者引入 Assumption 2 (Stability Condition): > For some \(\delta>0, R_1>0, K_0>0\), for all \(t\ge R_1\), \(t^{1+\delta} Q_t \le K_0\) almost surely, > where \(Q_t = \mathbb{E}_X[|\hat\pi_t(X) - \hat\pi_{t-1}(X)|]\)。 - 关键直觉:政策变化必须以至少 \(O(t^{-1-\delta})\) 的速率衰减,与用于评估该变化的样本量(约 \(T-t\) 个 batch)匹配,从而保证估计误差可忽略。 (幻灯片10) - [0:22:00] 讲者提出一种通用 stabilizer:在每轮以概率 \(p_t = \min\{C t^{-1-\delta}, 1\}\) 接受新政策 \(\tilde\pi_t\),否则沿用旧政策 \(\hat\pi_{t-1}\)。这使得任何学习算法都能满足稳定性条件(幻灯片11)。 - 实际上,只要算法在 80% 的数据上不需要被 stabilizer 拒绝,就能保证较好的学习效果。

3.5 主要理论结果 [0:24-0:30] - Theorem 1 (L1 Consistency) [0:24:00]:\(\mathbb{E}[|\hat\Delta(\hat\pi_T;\pi_0) - \Delta(\hat\pi_T;\pi_0)|] \to 0\)\(T\to\infty\)。误差分解为 missing term(可忽略)+ 累积估计误差(用条件方差和 Jensen 不等式控制,最终为 \(O(T^{-\delta/2})\) 量级)。 - 条件要求:稳定性条件 + 有界矩 + overlap。 - Theorem 2 (Asymptotic Normality) [0:26:00]: > \(\sqrt{T} \frac{\hat\Delta(\hat\pi_T;\pi_0) - \Delta(\hat\pi_T;\pi_0)}{v_T} \xrightarrow{d} N(0,1)\), > 其中 \(v_T^2 = T\sum_{j=2}^T \mathbb{V}(\hat\Gamma_j(T) \mid \mathcal{H}_{j-1})\)。 - 关键创新:CLT 针对的是学习规则(随机变量),不是固定参数。证明利用条件独立性加上 Berry-Esseen 界,结合“下三角区域(小样本部分)渐近可忽略”的论证(幻灯片15-16)。 - Variance estimation [0:26:30]:给出一个基于残差方差的一致估计量 \(\hat v_T^2\),可通过 IPW 形式的样本残差计算(幻灯片17)。

3.6 直觉与注意事项 [0:30-0:34] - [0:30:00] “为什么 Cram 能在数据效率上超过 sample splitting?”:因为评估从政策变化最大的早期就开始了,恰好那时有最多数据可用于评估。政策稳定后,评估数据变少但被评估的变化也很小。 - Batch size 的选择:更小的 batch 使数据利用更充分,但政策变化可能不平稳。推荐实践中用约 5% 的 batch size。 - Batch size 也可以变化(目前无理论结果)。

3.7 仿真研究 [0:34-0:38] - 使用 ACIC 2016 数据集的 77 个 DGP(CATE 估计任务)比较 Cram 和 80/20 sample splitting。 - Policy value 提升:Cram 通常有约 0-2% 的改进(更多训练数据的贡献)。 - 标准误差降低:Cram 的 SE 平均降低 30-40%,这是最显著的优势。 - 偏差:Cram 略大,但相对 policy value 约为 ±2% 量级,不影响覆盖率。 - 95% CI 覆盖率:Cram 稳定在 0.90-1.00,与 sample splitting 相当。 - [0:37:00] 不同样本规模下,Cram 的标准误差相对改善保持恒定的 30-40%,说明该优势不会随样本量增长而消失。 - 仿真中的学习器:S-learner、M-learner(ridge regression & neural nets)、因果森林。Causal Forest 因为本身就非常稳定,Cram 的额外改进最小。

3.8 实证应用 [0:38-0:41] - 数据:晚期前列腺癌合成雌激素临床试验(之前已分析存在异质性治疗效果)。 - Cram vs Sample Splitting: - 估计的治疗比例相近(≈57%); - 估计的 Policy Value:Cram 7.77 vs Splitting 3.90(翻倍提升,但标准误大); - 标准误:Cram 4.42 vs Splitting 6.65(降幅约 33%,与仿真一致); - 90% CI 在 Cram 下不包含 0,而在 Splitting 下包含 0。 - [0:40:30] 从学习到的政策来看,Cram 学到的政策倾向于给第四期癌症患者(stage 4) 更高概率的治疗,这与当前临床指南一致。

3.9 总结与未来工作 [0:41-0:43] - Cram 是一种通用方法论,不是仅限政策学习。 - 已提出的扩展: - Bandit(利用顺序结构); - Active Learning(样本效率关键); - 更一般的机器学习预测与分类; - Cramming cross-validation(将 Cram 的思想与 CV 结合)。 - [0:42:00] 讲者承认目前“仍在理解自己的方法”,并希望得到反馈。


四、对应论文与开放问题

(a) 对应论文

  • 标题:基于幻灯片末尾“Paper available at”的链接,以及报告标题,对应论文为: > The Cram Method for Efficient Simultaneous Learning and Evaluation > Kosuke Imai, Zeyang Jia, Michael Lingzhi Li > arXiv: 2403.07031 (https://arxiv.org/pdf/2403.07031.pdf)

    合作者信息在幻灯片第1页与转写 [0:01:50-0:02:00] 得到确认。

  • 讨论人 Rui Song 提到的相关工作:

    • SAVE (Cai, Song et al., JSSP 2022):用于无限 Horizon 强化学习的序贯划分与评估方法。
    • Dream (Cai, Song et al.):用于 Bandit 的 Doubly Robust Interval Estimation。
    • 讲者表示自己之前未见过这些工作,但认为思想非常接近。

(b) 报告留下的开放问题(每条扎根于转写具体位置)

  1. 最优 Batch Size 的理论推导 [0:32:00]:“we do not yet know optimal batch size / batch size can vary too”。讲者推荐 5%,但缺乏理论指导。是一个明确的开放问题。

  2. 放松稳定性条件 [0:22:30]:讲者提到正在探索将 uniform 条件 relax 到只要求极限行为的版本。此外 [0:23:30] 提到“如何选择 batch size 作为样本量函数”也是一个子问题。

  3. 扩展到非 i.i.d. 设置:讨论人 Rui Song 和讲者均明确指出 [0:41:30, 0:47:00-0:50:00] 将 Cram 扩展到 Bandit 或强化学习时,数据不再是 i.i.d.,稳定性条件必须重新定义(因为政策变化速度受探索-利用权衡影响)。

  4. 加权平均能否进一步提高效率:讨论人 Rui Song [0:50:00] 指出 Cram 目前对每个迭代的评估取(平均)\(1/(T-t)\),但也许可以用最优权重(如 inverse variance weighting)来提高最终估计量的效率。讲者承认在在线设置中已有类似思想(Dream paper),但未在 Cram 的框架下探索。

  5. 与 SAVE/Dream 方法的精确理论对比:讨论中 Michael Li [0:56:00] 提到 Cram 的条件与 SAVE 的条件(tail condition, martingale CLT)有本质不同,直接比较需要单独的工作。


Maintained by 陈星宇 · Homepage · Source on GitHub

评论