- 1. Transformer 的基本结构回顾
- 2. 前向传播在 Transformer 中的执行
- 3. 反向传播在 Transformer 中的执行
- 4. 哪些层执行前向/反向传播
- 图例
- Transformer Block 文本示意
- 开启 Gradient Checkpoint 后(节省显存)
好的,我们把问题缩小到 Transformer 架构 里,来看前向传播和反向传播是如何执行的。为了清晰,我会分层次讲,从整体到细节。
1. Transformer 的基本结构回顾
一个标准 Transformer 的 Encoder 或 Decoder Block,通常包含以下几层(按顺序):
-
Multi-Head Attention (MHA)
- 输入:上一层输出(或 embedding)
- 输出:经过注意力加权的表示
-
Add & LayerNorm (残差 + 归一化)
-
Feed Forward Network (FFN)
- 两层线性层 + 激活(通常是 GELU 或 ReLU)
-
Add & LayerNorm (残差 + 归一化)
所以一个 Block 可以画作:
Input → MHA → Add&Norm → FFN → Add&Norm → Output
2. 前向传播在 Transformer 中的执行
原则:从输入到输出,一层层执行前向传播。
-
Embedding 层:
- 输入 token → embedding 向量 + positional encoding
-
每个 Transformer Block:
-
Multi-Head Attention:
- 计算 Query/Key/Value → 注意力权重 → 输出
- 保存输出(激活)用于反向传播
-
Add & LayerNorm:
- 残差连接 + 归一化
- 保存归一化输出
-
FFN:
- 线性 → 激活 → 线性
- 保存 FFN 输出
-
Add & LayerNorm
- 最终输出作为下一层输入
-
-
最后输出层(比如 LM head):
- 输出 logits → 预测概率
总结:前向传播就是从 Embedding 到最后输出,每层都计算并存激活。
3. 反向传播在 Transformer 中的执行
原则:从输出层开始,沿着网络倒着计算梯度。
-
输出层:
- 计算 loss 对 logits 的梯度
-
最后一个 Block:
- Add&Norm → 梯度向 FFN 和残差路径传播
- FFN → 梯度回传到线性层权重
- Add&Norm → 梯度回传到 MHA 输出
- MHA → 梯度回传到 Query/Key/Value 权重
-
前一层 Block:
- 重复相同流程
-
Embedding 层:
- 计算 embedding 权重的梯度
注意点:
-
梯度计算需要前向的激活:
- FFN 层的输入
- MHA 的 Q/K/V
- LayerNorm 的输入
-
如果用了 gradient checkpoint:
- 不保存某些中间激活
- 反向传播时需要重新计算前向来得到激活,再算梯度
4. 哪些层执行前向/反向传播
| 层类型 | 前向传播 | 反向传播 |
|---|---|---|
| Embedding | X → embedding | 梯度回传到 embedding 权重 |
| MHA (Attention) | Q/K/V → attention → 输出 | 梯度回传到 Q/K/V 权重,算梯度时可能需要前向激活 |
| FFN | Linear → Activation → Linear | 梯度回传到两层线性层权重,Activation 需要前向值 |
| LayerNorm | 归一化 | 梯度回传到输入和归一化参数 |
| Output / LM head | Linear → LogSoftmax | 梯度回传到最后线性层 |
总结一句话:
Transformer 每一层的前向传播先算输出,反向传播在梯度链上倒着执行,每层都用前向激活来算梯度。
如果你想,我可以画一张Transformer Block 前向和反向传播的示意图,标出激活保存和 gradient checkpoint 的位置,让你直观理解“什么时候重新算一遍”。
图例
好的,我用文本示意图来展示一个 Transformer Block 的前向传播、反向传播,以及 gradient checkpoint 机制下的“重新算一遍”。
Transformer Block 文本示意
前向传播 (Forward Pass)
------------------------
Input│▼
[Multi-Head Attention] ← 前向输出保存(activation)│▼
[Add & LayerNorm] ← 前向输出保存(activation)│▼
[Feed Forward Network]│ ├─ Linear1│ ├─ Activation│ └─ Linear2 ← 前向输出保存▼
[Add & LayerNorm] ← 前向输出保存│▼
Output (送入下一个 Block)
反向传播 (Backward Pass)
------------------------
Output gradient│▼
[Add & LayerNorm] ← 使用前向激活计算梯度│▼
[Feed Forward Network]│ ├─ Linear2 gradient│ ├─ Activation gradient│ └─ Linear1 gradient▼
[Add & LayerNorm]│▼
[Multi-Head Attention]│ ├─ Output gradient│ └─ Q/K/V gradient▼
Input gradient
开启 Gradient Checkpoint 后(节省显存)
- 假设只在
[Add&LayerNorm]层存 checkpoint,FFN 和 MHA 不存中间激活 - 反向传播时:
Input gradient│▼
[Multi-Head Attention] ← 没存激活 → 重新执行前向计算得到中间激活,再算梯度│▼
[Add & LayerNorm] ← checkpoint 激活存在,直接算梯度│▼
[Feed Forward Network] ← 没存激活 → 重新执行前向计算得到激活,再算梯度│▼
[Add & LayerNorm] ← checkpoint 激活存在│▼
Output gradient
关键点:
- 没存激活的层 → 反向传播时需要重新算一次前向传播
- checkpoint 层 → 激活直接可用,不需要重新算
- 这样就用计算时间换显存空间