文章目录
- Week 31: 深度学习补遗:Mamba
- 摘要
- Abstract
- 1. 连续微分到离散递归
- 1.1 连续系统的物理意义
- 1.2 零阶保持 (ZOH)
- 2. 选择性扫描 (Selective Scan)
- 2.1 传统LTI系统
- 2.2 Mamba 的动态门控机制
- 3. 代码实现
- 4. RNN、Transformer和Mamba的优劣
- 总结
Week 31: 深度学习补遗:Mamba
摘要
本周的研究了2024年的热门深度学习架构Mamba。从底层的微分方程离散化 (Discretization)入手,理解了 Mamba 的核心创新——选择性扫描 (Selective Scan),解释了模型是如何通过动态调整时间步长Δ \DeltaΔ来实现类似于 LSTM 门控的“遗忘”与“记忆”机制。
Abstract
This week’s research examined Mamba, a prominent deep learning architecture for 2024. Beginning with the underlying discretisation of differential equations, we explored Mamba’s core innovation—Selective Scan—and elucidated how the model achieves LSTM-like gating mechanisms for “forgetting” and “remembering” by dynamically adjusting the time step sizeΔ \DeltaΔ.
1. 连续微分到离散递归
1.1 连续系统的物理意义
SSM 的起点是描述物理系统随时间变化的微分方程。
h ′ ( t ) = A h ( t ) + B x ( t ) h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)h′(t)=Ah(t)+Bx(t)
h ( t ) h(t)h(t)是系统的“当前状态”(例如弹簧的位置和速度)。A \mathbf{A}A矩阵描述系统在没有外力时的自然演变(例如弹簧的阻尼衰减)。B x ( t ) \mathbf{B}x(t)Bx(t)是外力输入对状态的影响。
1.2 零阶保持 (ZOH)
在深度学习中,数据不是连续的,而是离散的采样点x 0 , x 1 , … , x k x_0, x_1, \dots, x_kx0,x1,…,xk。我们需要将上述微分方程转化为递归式:h t = A ‾ h t − 1 + B ‾ x t h_t = \overline{\mathbf{A}} h_{t-1} + \overline{\mathbf{B}} x_tht=Aht−1+Bxt。
假设在时间区间[ t , t + Δ ] [t, t+\Delta][t,t+Δ]内,输入x ( t ) x(t)x(t)保持恒定值x k x_kxk(即“零阶保持”)。
对微分方程两边积分,解得t + Δ t+\Deltat+Δ时刻的状态:
h ( t + Δ ) = e Δ A h ( t ) + ∫ t t + Δ e ( t + Δ − τ ) A B x ( τ ) d τ h(t+\Delta) = e^{\Delta \mathbf{A}} h(t) + \int_{t}^{t+\Delta} e^{(t+\Delta-\tau)\mathbf{A}} \mathbf{B} x(\tau) d\tauh(t+Δ)=eΔAh(t)+∫tt+Δe(t+Δ−τ)ABx(τ)dτ
由于假设x ( τ ) x(\tau)x(τ)在该区间是常数,可以提到积分号外面。积分部分∫ 0 Δ e u A d u = A − 1 ( e Δ A − I ) \int_{0}^{\Delta} e^{u\mathbf{A}} du = \mathbf{A}^{-1}(e^{\Delta \mathbf{A}} - \mathbf{I})∫0ΔeuAdu=A−1(eΔA−I)。
由此得到具体的离散化参数公式:
状态转移矩阵A ‾ \overline{\mathbf{A}}A:
A ‾ = exp ( Δ ⋅ A ) \overline{\mathbf{A}} = \exp(\Delta \cdot \mathbf{A})A=exp(Δ⋅A)- 具体含义:Δ \DeltaΔ越大,A ‾ \overline{\mathbf{A}}A变化越剧烈。如果A \mathbf{A}A是负的(通常为了稳定性设为对角负矩阵),exp ( Δ A ) \exp(\Delta \mathbf{A})exp(ΔA)就会趋近于 0。
- 时间步长Δ \DeltaΔ越大,系统遗忘历史信息越快。
输入投影矩阵B ‾ \overline{\mathbf{B}}B:
B ‾ = ( Δ A ) − 1 ( exp ( Δ A ) − I ) ⋅ Δ B \overline{\mathbf{B}} = (\Delta \mathbf{A})^{-1} (\exp(\Delta \mathbf{A}) - \mathbf{I}) \cdot \Delta \mathbf{B}B=(ΔA)−1(exp(ΔA)−I)⋅ΔB- 近似:在Δ \DeltaΔ很小时,泰勒展开第一项主导,可近似为B ‾ ≈ Δ ⋅ B \overline{\mathbf{B}} \approx \Delta \cdot \mathbf{B}B≈Δ⋅B。
- 时间步长Δ \DeltaΔ越大,当前输入x t x_txt对状态的影响权重越大。
2. 选择性扫描 (Selective Scan)
2.1 传统LTI系统
在 Mamba 之前的 S4 模型中,Δ , A , B \Delta, \mathbf{A}, \mathbf{B}Δ,A,B都是静态参数(训练完就固定了)。
这相当于模型用同一套滤波器处理所有数据。
假设输入序列是[有用信息, 噪音, 噪音, 有用信息]。LTI 系统无法在遇到“噪音”时主动切断记忆,也无法在遇到“有用信息”时以此为重。
2.2 Mamba 的动态门控机制
Mamba 将参数变成了输入的函数:
( Δ t , B t , C t ) = Linear ( x t ) (\Delta_t, \mathbf{B}_t, \mathbf{C}_t) = \text{Linear}(x_t)(Δt,Bt,Ct)=Linear(xt)
这意味着对于序列中的每一个 Tokenx t x_txt,模型都会生成一套独一无二的离散化参数。
需要忽略的噪声
- 模型检测到x t x_txt是噪声。
- 预测出的Δ t \Delta_tΔt变大(例如从 0.1 变为 10)。
- 结果:A ‾ t = exp ( − 10 ) ≈ 0 \overline{\mathbf{A}}_t = \exp(-10) \approx 0At=exp(−10)≈0。
- 效果:h t = 0 ⋅ h t − 1 + … h_t = 0 \cdot h_{t-1} + \dotsht=0⋅ht−1+…,历史状态被清空/遗忘,噪声没有被长期记忆。
需要保留的关键信号
- 模型检测到x t x_txt是关键信号。
- 预测出的Δ t \Delta_tΔt变小,且B t \mathbf{B}_tBt变大。
- 结果:A ‾ t ≈ 1 \overline{\mathbf{A}}_t \approx 1At≈1,B ‾ t \overline{\mathbf{B}}_tBt很大。
- 效果:历史记忆无损保留,当前输入被强力写入状态。
这就是 Mamba 被称为“具有选择性”的具体原因——它通过动态调整时间刻度Δ \DeltaΔ,实现了类似于 LSTM 中“遗忘门”和“输入门”的功能,但计算效率更高。
3. 代码实现
importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassMambaBlockSpecific(nn.Module):def__init__(self,d_model,d_state=16):super().__init__()self.d_model=d_model# D: 输入维度 (例如 512)self.d_state=d_state# N: 状态维度 (例如 16)# 定义 A, D 参数self.A_log=nn.Parameter(torch.log(torch.randn(d_model,d_state).abs()))self.D=nn.Parameter(torch.ones(d_model))# 核心投影层self.in_proj=nn.Linear(d_model,d_model*2)# x_proj 负责从输入生成 B, C, Delta# 输出维度是 N + N + 1 (B的状态, C的状态, Delta的标量)self.x_proj=nn.Linear(d_model,(d_state*2)+1)self.dt_proj=nn.Linear(1,d_model)defforward(self,x):# x Shape: [Batch, Seq_Len, D]B,L,D=x.shape N=self.d_state# 1. 扩展输入维度x_and_res=self.in_proj(x)# [B, L, 2*D]x_in,res=x_and_res.chunk(2,dim=-1)# x_in: [B, L, D], res: [B, L, D]# ----------------------------------------------------# 2. 动态参数生成 (Discretization & Selection)# ----------------------------------------------------# 这一步是 Mamba 区别于 S4 的关键:参数随 t 变化ssm_params=self.x_proj(x_in)# [B, L, 2*N + 1]# 切分出 B_t, C_t, dt_tB_t=ssm_params[:,:,:N]# [B, L, N]C_t=ssm_params[:,:,N:2*N]# [B, L, N]dt_t=ssm_params[:,:,2*N:]# [B, L, 1]# 广播 Delta: 将标量 Delta 投影回特征维度 Ddt_t=F.softplus(self.dt_proj(dt_t))# [B, L, D]# 计算离散化的 A_bar (Decay Rate)# A 是 (D, N), dt_t 是 (B, L, D) -> 广播计算# exp(Delta * A) 决定了记忆衰减的速率A=-torch.exp(self.A_log)dA=torch.exp(torch.einsum('bld,dn->bldn',dt_t,A))# [B, L, D, N]# 计算离散化的 B_bar (Input Weight)# B_bar = Delta * BdB=torch.einsum('bld,bln->bldn',dt_t,B_t)# [B, L, D, N]# ----------------------------------------------------# 3. 状态空间扫描 (SSM Scan)# ----------------------------------------------------# 初始化隐状态 h: [B, D, N]# 注意:这里的 h 是 latent state,相比 Transformer 的 KV Cache 极小h=torch.zeros(B,D,N,device=x.device)ys=[]# 串行扫描演示 (实际 CUDA kernel 会使用并行前缀和优化)fortinrange(L):# h_t = A_bar_t * h_{t-1} + B_bar_t * x_t# 具体操作:# 1. 衰减历史: h * dA[:, t] (按元素乘,不同通道衰减率不同)# 2. 写入新值: x_in[:, t] * dB[:, t]h=dA[:,t]*h+dB[:,t]*x_in[:,t].unsqueeze(-1)# y_t = C_t * h_t# 将隐状态 N 投影回输出维度 Dy_t=torch.einsum('bdn,bn->bd',h,C_t[:,t])ys.append(y_t)y=torch.stack(ys,dim=1)# [B, L, D]# 残差连接与输出y=y+x_in*self.Dreturny*F.silu(res)4. RNN、Transformer和Mamba的优劣
相比于 RNN (LSTM/GRU),Mamba 克服了其门控机制虽动态但无法利用 GPU 并行训练的瓶颈。
Mamba 虽然在逻辑上保持了递归形式(h t = A ‾ t h t − 1 + … h_t = \overline{\mathbf{A}}_t h_{t-1} + \dotsht=Atht−1+…),但由于采用了不包含tanh \tanhtanh等非线性激活的线性递归,从而可以利用数学上的结合律引入并行前缀扫描算法,实现了像 Transformer 一样极速的并行训练。
而在与 Transformer 的对比中,Mamba 解决了 Attention 矩阵需存储L × L L \times LL×L历史交互以及推理时 KV Cache 显存占用巨大的问题。Mamba 在推理时仅需维护一个大小为D × N D \times ND×N的固定状态h t h_tht,这意味着无论序列长度是 1k 还是 100k,其推理所需的显存和计算量始终保持恒定。
总结
本周学习了Mamba,不同于 Transformer 的注意力机制,Mamba 是状态空间模型 (SSM)*的集大成者。通过具体推导了零阶保持 (ZOH) 技术的数学细节,我理解了选择性扫描 (Selective Scan)是如何实现类似门控的效果的。而代码复现部分,我更重点聚焦了动态参数生成的具体维度变化与数据流向。Mamba有其局限性,但其模块具有比较显著的创新性,可以考虑后续融合一部分模块在未来研究中。