赤峰市网站建设_网站建设公司_模板建站_seo优化
2025/12/29 7:31:03 网站建设 项目流程

文章目录

  • Week 31: 深度学习补遗:Mamba
    • 摘要
    • Abstract
    • 1. 连续微分到离散递归
      • 1.1 连续系统的物理意义
      • 1.2 零阶保持 (ZOH)
    • 2. 选择性扫描 (Selective Scan)
      • 2.1 传统LTI系统
      • 2.2 Mamba 的动态门控机制
    • 3. 代码实现
    • 4. RNN、Transformer和Mamba的优劣
    • 总结

Week 31: 深度学习补遗:Mamba

摘要

本周的研究了2024年的热门深度学习架构Mamba。从底层的微分方程离散化 (Discretization)入手,理解了 Mamba 的核心创新——选择性扫描 (Selective Scan),解释了模型是如何通过动态调整时间步长Δ \DeltaΔ来实现类似于 LSTM 门控的“遗忘”与“记忆”机制。

Abstract

This week’s research examined Mamba, a prominent deep learning architecture for 2024. Beginning with the underlying discretisation of differential equations, we explored Mamba’s core innovation—Selective Scan—and elucidated how the model achieves LSTM-like gating mechanisms for “forgetting” and “remembering” by dynamically adjusting the time step sizeΔ \DeltaΔ.

1. 连续微分到离散递归

1.1 连续系统的物理意义

SSM 的起点是描述物理系统随时间变化的微分方程。
h ′ ( t ) = A h ( t ) + B x ( t ) h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)h(t)=Ah(t)+Bx(t)
h ( t ) h(t)h(t)是系统的“当前状态”(例如弹簧的位置和速度)。A \mathbf{A}A矩阵描述系统在没有外力时的自然演变(例如弹簧的阻尼衰减)。B x ( t ) \mathbf{B}x(t)Bx(t)是外力输入对状态的影响。

1.2 零阶保持 (ZOH)

在深度学习中,数据不是连续的,而是离散的采样点x 0 , x 1 , … , x k x_0, x_1, \dots, x_kx0,x1,,xk。我们需要将上述微分方程转化为递归式:h t = A ‾ h t − 1 + B ‾ x t h_t = \overline{\mathbf{A}} h_{t-1} + \overline{\mathbf{B}} x_tht=Aht1+Bxt

假设在时间区间[ t , t + Δ ] [t, t+\Delta][t,t+Δ]内,输入x ( t ) x(t)x(t)保持恒定值x k x_kxk(即“零阶保持”)。

对微分方程两边积分,解得t + Δ t+\Deltat+Δ时刻的状态:
h ( t + Δ ) = e Δ A h ( t ) + ∫ t t + Δ e ( t + Δ − τ ) A B x ( τ ) d τ h(t+\Delta) = e^{\Delta \mathbf{A}} h(t) + \int_{t}^{t+\Delta} e^{(t+\Delta-\tau)\mathbf{A}} \mathbf{B} x(\tau) d\tauh(t+Δ)=eΔAh(t)+tt+Δe(t+Δτ)ABx(τ)dτ

由于假设x ( τ ) x(\tau)x(τ)在该区间是常数,可以提到积分号外面。积分部分∫ 0 Δ e u A d u = A − 1 ( e Δ A − I ) \int_{0}^{\Delta} e^{u\mathbf{A}} du = \mathbf{A}^{-1}(e^{\Delta \mathbf{A}} - \mathbf{I})0ΔeuAdu=A1(eΔAI)

由此得到具体的离散化参数公式:

  1. 状态转移矩阵A ‾ \overline{\mathbf{A}}A
    A ‾ = exp ⁡ ( Δ ⋅ A ) \overline{\mathbf{A}} = \exp(\Delta \cdot \mathbf{A})A=exp(ΔA)

    • 具体含义Δ \DeltaΔ越大,A ‾ \overline{\mathbf{A}}A变化越剧烈。如果A \mathbf{A}A是负的(通常为了稳定性设为对角负矩阵),exp ⁡ ( Δ A ) \exp(\Delta \mathbf{A})exp(ΔA)就会趋近于 0。
    • 时间步长Δ \DeltaΔ越大,系统遗忘历史信息越快。
  2. 输入投影矩阵B ‾ \overline{\mathbf{B}}B
    B ‾ = ( Δ A ) − 1 ( exp ⁡ ( Δ A ) − I ) ⋅ Δ B \overline{\mathbf{B}} = (\Delta \mathbf{A})^{-1} (\exp(\Delta \mathbf{A}) - \mathbf{I}) \cdot \Delta \mathbf{B}B=(ΔA)1(exp(ΔA)I)ΔB

    • 近似:在Δ \DeltaΔ很小时,泰勒展开第一项主导,可近似为B ‾ ≈ Δ ⋅ B \overline{\mathbf{B}} \approx \Delta \cdot \mathbf{B}BΔB
    • 时间步长Δ \DeltaΔ越大,当前输入x t x_txt对状态的影响权重越大。

2. 选择性扫描 (Selective Scan)

2.1 传统LTI系统

在 Mamba 之前的 S4 模型中,Δ , A , B \Delta, \mathbf{A}, \mathbf{B}Δ,A,B都是静态参数(训练完就固定了)。
这相当于模型用同一套滤波器处理所有数据。

假设输入序列是[有用信息, 噪音, 噪音, 有用信息]。LTI 系统无法在遇到“噪音”时主动切断记忆,也无法在遇到“有用信息”时以此为重。

2.2 Mamba 的动态门控机制

Mamba 将参数变成了输入的函数:
( Δ t , B t , C t ) = Linear ( x t ) (\Delta_t, \mathbf{B}_t, \mathbf{C}_t) = \text{Linear}(x_t)(Δt,Bt,Ct)=Linear(xt)

这意味着对于序列中的每一个 Tokenx t x_txt,模型都会生成一套独一无二的离散化参数。

  1. 需要忽略的噪声

    • 模型检测到x t x_txt是噪声。
    • 预测出的Δ t \Delta_tΔt变大(例如从 0.1 变为 10)。
    • 结果:A ‾ t = exp ⁡ ( − 10 ) ≈ 0 \overline{\mathbf{A}}_t = \exp(-10) \approx 0At=exp(10)0
    • 效果h t = 0 ⋅ h t − 1 + … h_t = 0 \cdot h_{t-1} + \dotsht=0ht1+,历史状态被清空/遗忘,噪声没有被长期记忆。
  2. 需要保留的关键信号

    • 模型检测到x t x_txt是关键信号。
    • 预测出的Δ t \Delta_tΔt变小,且B t \mathbf{B}_tBt变大。
    • 结果:A ‾ t ≈ 1 \overline{\mathbf{A}}_t \approx 1At1B ‾ t \overline{\mathbf{B}}_tBt很大。
    • 效果:历史记忆无损保留,当前输入被强力写入状态。

这就是 Mamba 被称为“具有选择性”的具体原因——它通过动态调整时间刻度Δ \DeltaΔ,实现了类似于 LSTM 中“遗忘门”和“输入门”的功能,但计算效率更高。

3. 代码实现

importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassMambaBlockSpecific(nn.Module):def__init__(self,d_model,d_state=16):super().__init__()self.d_model=d_model# D: 输入维度 (例如 512)self.d_state=d_state# N: 状态维度 (例如 16)# 定义 A, D 参数self.A_log=nn.Parameter(torch.log(torch.randn(d_model,d_state).abs()))self.D=nn.Parameter(torch.ones(d_model))# 核心投影层self.in_proj=nn.Linear(d_model,d_model*2)# x_proj 负责从输入生成 B, C, Delta# 输出维度是 N + N + 1 (B的状态, C的状态, Delta的标量)self.x_proj=nn.Linear(d_model,(d_state*2)+1)self.dt_proj=nn.Linear(1,d_model)defforward(self,x):# x Shape: [Batch, Seq_Len, D]B,L,D=x.shape N=self.d_state# 1. 扩展输入维度x_and_res=self.in_proj(x)# [B, L, 2*D]x_in,res=x_and_res.chunk(2,dim=-1)# x_in: [B, L, D], res: [B, L, D]# ----------------------------------------------------# 2. 动态参数生成 (Discretization & Selection)# ----------------------------------------------------# 这一步是 Mamba 区别于 S4 的关键:参数随 t 变化ssm_params=self.x_proj(x_in)# [B, L, 2*N + 1]# 切分出 B_t, C_t, dt_tB_t=ssm_params[:,:,:N]# [B, L, N]C_t=ssm_params[:,:,N:2*N]# [B, L, N]dt_t=ssm_params[:,:,2*N:]# [B, L, 1]# 广播 Delta: 将标量 Delta 投影回特征维度 Ddt_t=F.softplus(self.dt_proj(dt_t))# [B, L, D]# 计算离散化的 A_bar (Decay Rate)# A 是 (D, N), dt_t 是 (B, L, D) -> 广播计算# exp(Delta * A) 决定了记忆衰减的速率A=-torch.exp(self.A_log)dA=torch.exp(torch.einsum('bld,dn->bldn',dt_t,A))# [B, L, D, N]# 计算离散化的 B_bar (Input Weight)# B_bar = Delta * BdB=torch.einsum('bld,bln->bldn',dt_t,B_t)# [B, L, D, N]# ----------------------------------------------------# 3. 状态空间扫描 (SSM Scan)# ----------------------------------------------------# 初始化隐状态 h: [B, D, N]# 注意:这里的 h 是 latent state,相比 Transformer 的 KV Cache 极小h=torch.zeros(B,D,N,device=x.device)ys=[]# 串行扫描演示 (实际 CUDA kernel 会使用并行前缀和优化)fortinrange(L):# h_t = A_bar_t * h_{t-1} + B_bar_t * x_t# 具体操作:# 1. 衰减历史: h * dA[:, t] (按元素乘,不同通道衰减率不同)# 2. 写入新值: x_in[:, t] * dB[:, t]h=dA[:,t]*h+dB[:,t]*x_in[:,t].unsqueeze(-1)# y_t = C_t * h_t# 将隐状态 N 投影回输出维度 Dy_t=torch.einsum('bdn,bn->bd',h,C_t[:,t])ys.append(y_t)y=torch.stack(ys,dim=1)# [B, L, D]# 残差连接与输出y=y+x_in*self.Dreturny*F.silu(res)

4. RNN、Transformer和Mamba的优劣

相比于 RNN (LSTM/GRU),Mamba 克服了其门控机制虽动态但无法利用 GPU 并行训练的瓶颈。

Mamba 虽然在逻辑上保持了递归形式(h t = A ‾ t h t − 1 + … h_t = \overline{\mathbf{A}}_t h_{t-1} + \dotsht=Atht1+),但由于采用了不包含tanh ⁡ \tanhtanh等非线性激活的线性递归,从而可以利用数学上的结合律引入并行前缀扫描算法,实现了像 Transformer 一样极速的并行训练。

而在与 Transformer 的对比中,Mamba 解决了 Attention 矩阵需存储L × L L \times LL×L历史交互以及推理时 KV Cache 显存占用巨大的问题。Mamba 在推理时仅需维护一个大小为D × N D \times ND×N的固定状态h t h_tht,这意味着无论序列长度是 1k 还是 100k,其推理所需的显存和计算量始终保持恒定

总结

本周学习了Mamba,不同于 Transformer 的注意力机制,Mamba 是状态空间模型 (SSM)*的集大成者。通过具体推导了零阶保持 (ZOH) 技术的数学细节,我理解了选择性扫描 (Selective Scan)是如何实现类似门控的效果的。而代码复现部分,我更重点聚焦了动态参数生成的具体维度变化与数据流向。Mamba有其局限性,但其模块具有比较显著的创新性,可以考虑后续融合一部分模块在未来研究中。

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询