郴州市网站建设_网站建设公司_Java_seo优化
2025/12/22 19:19:00 网站建设 项目流程

论文信息

论文标题:Graph Rationalization with Environment-based Augmentations
论文作者:王宇杰、于奎、张玉宏、曹付源、梁吉业
论文来源:ICLR'2024
发布时间:2024
论文地址:link
论文代码:link

2 研究动机&&研究问题

2.1 研究动机

1. 图属性预测的领域需求与数据瓶颈

  • 领域重要性:图神经网络(GNN)在化学信息学(分子属性预测)、材料信息学(聚合物属性预测)等领域应用广泛。例如,分子属性预测可辅助药物研发,聚合物属性预测(如氧气渗透性、玻璃化转变温度)能加速高性能材料(如气体分离膜、耐高温聚合物)的发现,解决工程与环境领域的关键挑战(如 biodegradability、高温稳定性需求)。

  • 数据痛点:这类任务的核心瓶颈是数据集规模小—— 分子基准数据集通常仅 1000-10000 个图,聚合物数据集更小(如 O₂Perm 仅 595 个样本)。小数据导致 GNN 模型易过拟合、泛化能力差,难以稳定学习图结构与属性间的因果关系。

2. 现有图合理化方法的缺陷

  • 图合理化的核心目标:识别对预测结果起决定性作用的 “关键子图(Rationale)”,剩余部分为 “环境子图(Environment)”。关键子图需同时满足:① 单独用于预测时性能接近原图;② 具备可解释性(符合领域知识);③ 不受环境子图中噪声的干扰。

  • 现有方法的两大问题

    1. 样本不足导致关键子图识别不准确:现有方法(如图池化、DIR)依赖真实数据学习关键子图,但小数据集缺乏足够多样性,导致模型难以区分 “因果关键子图” 与 “非因果环境噪声”。例如,DIR 通过分布干预生成多分布数据,但在小样本场景下仍无法有效学习最优关键子图。

    2. 显式子图处理导致计算复杂度过高:现有方法(如 DIR)需显式解码关键子图和环境子图(如通过边缘掩码选择 Top-K 边构建子图),再重新编码子图表征,不仅计算成本高,还易丢失原图的节点上下文信息,导致关键子图结构不连贯、可解释性差。

3. 数据增强在图任务中的潜力与空白

  • 数据增强是解决小样本问题的有效手段,但现有图数据增强方法(如 DropEdge、GraphCrop)多为启发式结构修改(如裁剪子图、删除边),未针对图合理化设计,无法直接辅助关键子图的识别与分离。

  • 核心灵感:环境子图可视为 “自然噪声”,若通过 “环境替换” 生成虚拟样本,迫使模型聚焦于关键子图(因果部分)而非环境噪声(非因果部分),可同时解决 “样本不足” 和 “关键子图识别不准确” 的问题。

2.2 研究问题

1. 核心问题

  如何设计一种数据增强驱动的图合理化框架,在小样本场景下,高效、准确地识别图的关键子图,同时提升 GNN 模型的预测性能、泛化能力与可解释性?

2. 具体拆解问题

  Q1:如何利用 “关键子图 - 环境子图” 的分离特性,设计针对性的数据增强策略,为图合理化生成有效虚拟样本?

  Q2:如何避免显式子图解码 / 编码的高复杂度,在 latent 空间中完成关键子图 - 环境子图的分离与表征学习?

  Q3:该框架在真实分子 / 聚合物数据集上,是否能同时优于现有图池化、泛化优化、图合理化方法(如 DIR、OOD-GNN)?

  Q4:框架识别的关键子图是否符合领域知识(如聚合物化学中功能基团与气体渗透性的关系),具备实际可解释性?

  Q5:框架对超参数(如关键子图大小、聚合函数)是否敏感,能否保持稳定性能?

3 方法

3.1 核心框架概述(Core Framework Overview)

框架定位

  GREA(Graph Rationalization with Environment-based Augmentations)是一种基于环境增强的图合理化框架,核心目标是在小样本场景下,高效分离图的 “关键子图(Rationale)” 与 “环境子图(Environment)”,同时通过数据增强提升模型的预测准确性、泛化能力与可解释性。

核心设计思路

  • 避免显式子图解码 / 编码:在 latent 空间 中完成关键子图 - 环境子图的分离、增强样本生成与表征学习,降低计算复杂度;

  • 双增强驱动训练:结合 “环境移除增强” 和 “环境替换增强”,利用环境子图的 “自然噪声” 特性,为关键子图识别提供多样化训练信号;

  • 交替优化策略:分别训练 “分离器()” 和 “预测器()”,平衡关键子图分离精度与属性预测性能。

整体流程

  1. 分离器( $GNN_1 + MLP_1$ ):对输入图进行节点级掩码预测,实现关键子图与环境子图的初步分离;

  2. 表征生成器( $GNN_2$ ):生成图的节点上下文表征,为子图表征计算提供基础;

  3. 增强样本生成:在 latent 空间中生成 “环境移除样本”(仅关键子图)和 “环境替换样本”(关键子图 + 其他图的环境子图);

  4. 预测器( $MLP_2$ ):基于两类增强样本联合训练,优化整体损失函数,输出最终预测结果。

3.2 关键子图 - 环境子图分离(Rationale-Environment Separation)

核心目标

  通过可学习的掩码机制,在 latent 空间中精准分离输入图的关键子图(因果部分)与环境子图(非因果部分),无需显式构建子图结构。

核心组件与计算逻辑

1. 分离器结构

  • 功能:生成节点级掩码向量 $m$,表示每个节点属于关键子图的概率;

  • 掩码计算:

      $m = \sigma\left(MLP_1(GNN_1(g))\right)$

    其中:

      • $GNN_1$ :编码器,生成节点的 latent 表征(捕捉节点特征与局部拓扑);

      • $MLP_1$ :解码器,将节点 latent 表征映射为一维概率值;

      • $\sigma$ :sigmoid 激活函数,确保掩码值  $m_v \in (0,1)$(即节点 $v$ 属于关键子图的概率);

      • 环境掩码: $1_N - m$ $1_N$  为 N 维全 1 列向量),表示节点属于环境子图的概率。

2. 节点上下文表征生成

  • 功能:生成具备全局上下文信息的节点表征 $H$,为子图表征计算提供支持;

  • 计算方式: $H = GNN_2(g)$ ,其中  $GNN_2$  可选用 GCN、GIN 等模型,与  $GNN_1$  独立(避免表征偏置)。

3. 子图表征计算

  • 核心逻辑:通过掩码与节点表征的加权求和,在 latent 空间中聚合关键子图与环境子图的表征;

  • 关键子图表征  $h^{(r)}$

     $h^{(r)} = 1_N^{\top} \cdot (m \times H)$

  • 环境子图表征  $h^{(e)}$

     $h^{(e)} = 1_N^{\top} \cdot \left( (1_N - m) \times H \right)$

    其中, $h^{(r)}, h^{(e)} \in \mathbb{R}^d$ d 为表征维度),Sum Pooling 确保聚合后的表征保留子图的全局信息。

设计优势

  • 避免显式子图构建:无需解码节点 / 边的实际连接关系,直接在 latent 空间完成分离,降低计算复杂度;

  • 概率化掩码:相比硬掩码(0/1),软掩码(概率值)更平滑,训练更稳定,且能捕捉节点对关键子图的贡献度差异。

3.3 基于环境的增强策略(Environment-based Augmentations)

核心前提

  • 批次训练设定:设训练批次中包含 $B$ 个图  $g_1, g_2, ..., g_B$ ,通过 3.2 节方法已得到每个图的关键子图表征  $h_i^{(r)}$  和环境子图表征  $h_i^{(e)}$ $i = 1,2,...,B$ );
  • 增强目标:利用环境子图的 “噪声特性”,生成多样化样本,迫使模型聚焦于关键子图(因果部分),提升关键子图识别的准确性与鲁棒性。

3.3.1 环境移除增强(Environment Removal Augmentation)

核心逻辑

  关键子图作为图属性的因果核心,仅用其表征应能实现与原图接近的预测性能。通过 “移除环境子图”,仅保留关键子图用于训练,强化关键子图的预测能力。

预测计算

  给定图  $g_i$  的关键子图表征  $h_i^{(r)}$ ,预测器输出:

     $\hat{y}_i^{(r)} = MLP_2\left(h_i^{(r)}\right)$

  其中 $MLP_2$ 为属性预测器的解码器,与 3.1 节图属性预测器的 MLP 结构一致。

3.3.2 环境替换增强(Environment Replacement Augmentation)

核心逻辑

  环境子图是无关噪声,若将图  $g_i$  的关键子图与其他图  $g_j$ $j \neq i$ )的环境子图组合,生成的虚拟样本应与  $g_i$  具有相同标签(因关键子图未变)。通过这种替换,模型能学习到 “关键子图不变则标签不变” 的因果规律,忽略环境噪声干扰。

虚拟样本生成与预测

  1. 表征聚合:通过聚合函数  $AGG(\cdot, \cdot)$  组合  $h_i^{(r)}$ $g_i$  的关键子图表征)与  $h_j^{(e)}$ $g_j$  的环境子图表征),得到虚拟样本的表征  $h_{(i,j)}$

    • 聚合函数可选:求和池化(Sum Pooling,默认)、平均池化(Mean Pooling)、最大池化(Max Pooling)、拼接(Concatenation)等,公式示例(Sum Pooling):

      $h_{(i,j)} = AGG\left(h_i^{(r)}, h_j^{(e)}\right) = h_i^{(r)} + h_j^{(e)}$

  2. 预测计算:虚拟样本的预测标签应与  $g_i$  的真实标签  $y_i$  一致,预测公式:

       $\hat{y}_{(i,j)} = MLP_2\left(h_{(i,j)}\right)$

  3. 样本数量:每个图  $g_i$  可与批次中其他  $B-1$  个图的环境子图组合,生成  $B-1$  个虚拟样本,显著提升训练数据多样性。

增强策略的优势

  • 针对性强:专为图合理化设计,直接关联关键子图与环境子图的分离逻辑;

  • 无额外标注成本:虚拟样本的标签由原关键子图的标签继承,无需人工标注;

  • 兼容性高:在 latent 空间中实现,不依赖图的具体结构,适用于分子、聚合物等各类图数据。

3.4 优化目标与训练策略(Optimization & Training Schema)

3.4.1 损失函数设计

  损失函数分为三类,分别对应增强样本预测、关键子图大小正则化,确保模型兼顾预测精度与关键子图合理性。

1. 环境移除损失( $\mathcal{L}_{rem}$

  • 目标:优化 “环境移除样本” 的预测精度,确保关键子图具备独立预测能力;

  • 公式(以二分类任务为例,采用交叉熵损失):

    $\mathcal{L}_{rem} = y_i \cdot \log \hat{y}_i^{(r)} + (1 - y_i) \cdot \log (1 - \hat{y}_i^{(r)})$

  • 回归任务适配:若为连续属性预测(如聚合物密度),则替换为均方误差(MSE)损失。

2. 环境替换损失( $\mathcal{L}_{rep}$

  • 目标:优化 “环境替换样本” 的预测精度,确保模型忽略环境噪声,聚焦关键子图;

  • 公式(以二分类任务为例):

    $\mathcal{L}_{rep} = \frac{1}{B} \sum_{j=1}^{B} \left( y_i \cdot \log \hat{y}_{(i,j)} + (1 - y_i) \cdot \log (1 - \hat{y}_{(i,j)}) \right)$

    其中  $\frac{1}{B}$  为批次内虚拟样本的损失均值,平衡不同虚拟样本的贡献。

3. 正则化损失( $\mathcal{L}_{reg}$

  • 目标:控制关键子图的大小,避免关键子图过大(接近原图)或过小(丢失核心信息);

  • 公式:

    $\mathcal{L}_{reg} = \left| \frac{1_N^{\top} \cdot m}{N} - \gamma \right|$

    其中:

    • $\frac{1_N^{\top} \cdot m}{N}$ :关键子图的平均节点占比(掩码均值);

    • $\gamma \in [0,1]$ :超参数,控制关键子图的预期大小(如  $\gamma = 0.3$  表示关键子图约占原图 30% 的节点);

    • 损失意义:当关键子图实际大小与预期  $\gamma$  偏离时,产生惩罚,确保关键子图的紧凑性与完整性。

3.4.2 整体损失与交替训练

1. 整体损失函数

  • 预测器损失(训练  $GNN_2 + MLP_2$ ):

     $\mathcal{L}_{pred} = \mathcal{L}_{rem} + \alpha \cdot \mathcal{L}_{rep}$  

    其中,$\alpha$  为超参数,控制环境替换损失的权重;

  • 分离器损失(训练  $GNN_1 + MLP_1$ ):

     $\mathcal{L}_{sep} = \mathcal{L}_{rem} + \alpha \cdot \mathcal{L}_{rep} + \beta \cdot \mathcal{L}_{reg}$

    其中,$\beta$  为超参数,控制正则化损失的权重。

2. 交替训练策略

  • 核心逻辑:分离器与预测器存在相互依赖(分离器的掩码质量影响预测器性能,预测器的损失反馈优化分离器),因此采用交替训练:

    1. 固定分离器  $f_{sep}$ ,训练预测器  $f_{pred}$  共  $T_{pred}$  个 epoch;

    2. 固定预测器  $f_{pred}$ ,训练分离器  $f_{sep}$  共  $T_{sep}$  个 epoch;

    3. 重复上述步骤,直至模型收敛。

  • 超参数设置: $T_{sep} \in \{1,2\}$ $T_{pred} \in \{2,3\}$ (通过验证集调优),确保训练稳定且高效。

3.4.3 推理阶段逻辑

  • 推理时,仅需通过分离器  $f_{sep}$  得到输入图的关键子图表征  $h^{(r)}$ ,代入预测器  $MLP_2$  输出最终预测结果:

    $\hat{y} = MLP_2\left(h^{(r)}\right)$

  • 关键子图可视化:通过掩码 $m$ 筛选概率高于阈值(如 0.5)的节点,构建关键子图结构,用于可解释性分析。

 

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询