Mamba

张开发
2026/4/4 11:02:21 15 分钟阅读
Mamba
MambaMiniAt,Bt,Ct的计算都是并行的在官方mamba中SelectiveScan算法可以实现logN的复杂度但是这里为了简化SelectiveScan设计为了类似RNN的N的复杂度串行方式方便理解。OverViewSSMSelectiveScanMamba并行扫描LogN推导引言问题与目标问题Mamba的核心状态更新是一个看似无法并行的时变递归过程h_t A_bar_t * h_{t-1} B_bar_t * x_t其中A_bar_t和B_bar_t在每个时间步t都依赖于当前输入x_t。目标找到一种方法可以并行地计算出整个隐藏状态序列H [h_0, h_1, ..., h_{N-1}]以利用现代硬件如GPU的并行计算能力从而加速模型训练。核心思想将这个复杂的递归过程抽象成一个满足结合律的代数结构然后利用经典的并行扫描算法来高效求解。第一步定义操作单元T我们需要将每个时间步t的变换操作封装成一个独立的单元。预处理输入: 为了简化后续操作我们将B_bar_t和x_t预先乘在一起形成一个新的项B_t(B prime t)。B_t B_bar_t * x_t这个操作可以对所有时间步并行执行。简化递归: 现在的递归公式变为h_t A_bar_t * h_{t-1} B_t这个公式的结构变得更清晰新状态 变换矩阵 * 旧状态 新增量。定义操作单元T_t: 我们将第t步的完整变换操作定义为一个包含两个元素的“对”pairT_t (A_bar_t, B_t)A_bar_t: 代表对历史状态的“变换”部分。B_t: 代表在当前步“注入”的新信息部分。第二步定义成对操作⊕(Combine Operation)现在我们需要定义一个操作⊕它能将两个连续时间步的变换T_i和T_j其中ji1合并成一个等效的、跨越这两步的单一变换。推导:应用T_ih_i A_bar_i * h_{i-1} B_i接着应用T_jh_j A_bar_j * h_i B_j将h_i的表达式代入h_j A_bar_j * (A_bar_i * h_{i-1} B_i) B_j展开并重新整理h_j (A_bar_j * A_bar_i) * h_{i-1} (A_bar_j * B_i B_j)观察结构: 这个最终的表达式h_j (新A) * h_{i-1} (新B)告诉我们连续应用T_i和T_j的效果等同于应用一个新的、组合后的变换T_k (A_bar_k, B_k)其中A_bar_k A_bar_j * A_bar_iB_k A_bar_j * B_i B_j定义成对操作⊕: 基于此我们正式定义⊕操作T_j ⊕ T_i (A_bar_j * A_bar_i, A_bar_j * B_i B_j)第三步证明结合律 (Associativity)结合律是使用并行扫描算法的充要条件。我们需要证明(T_k ⊕ T_j) ⊕ T_i T_k ⊕ (T_j ⊕ T_i)。计算左侧(T_k ⊕ T_j) ⊕ T_i:先算T_k ⊕ T_j (A_k*A_j, A_k*B_j B_k)再与T_i组合(A_k*A_j, A_k*B_j B_k) ⊕ (A_i, B_i)新A部分(A_k*A_j) * A_i A_k*A_j*A_i新B’部分(A_k*A_j) * B_i (A_k*B_j B_k) A_k*A_j*B_i A_k*B_j B_k计算右侧T_k ⊕ (T_j ⊕ T_i):先算T_j ⊕ T_i (A_j*A_i, A_j*B_i B_j)再与T_k组合(A_k, B_k) ⊕ (A_j*A_i, A_j*B_i B_j)新A部分A_k * (A_j*A_i) A_k*A_j*A_i新B’部分A_k * (A_j*B_i B_j) B_k A_k*A_j*B_i A_k*B_j B_k结论: 左右两侧的最终结果完全相同。结合律成立。第四步并行扫描算法原理与S和h的关系现在我们可以应用并行扫描了。并行扫描在这里具体指前缀和 (Prefix Sum)的目的是计算出序列S [S_0, S_1, S_2, ...]其中S_t T_t ⊕ T_{t-1} ⊕ ... ⊕ T_0。S的含义:S_t代表了从第0步到第t步所有变换累积起来的等效单一变换。如果我们把S_t写成(A_cumulative_t, B_cumulative_t)那么它的意思是h_t A_cumulative_t * h_{-1} B_cumulative_t与h的关系: 因为我们假设初始状态h_{-1} 0所以上式简化为h_t B_cumulative_t最终推论:我们想要计算的隐藏状态h_t正好就是对操作单元序列T执行并行扫描前缀和后得到的累积变换序列S中每个元素的第二个分量。算法实现: 并行扫描算法通过log(N)次并行的“上扫”和log(N)次并行的“下扫”步骤来高效地计算出完整的S序列从而得到所有h。上扫并行地构建出一棵包含所有“区间和”的树状结构。下扫并行地利用这棵树自顶向下传播“前缀”信息最终得到每个位置的精确前缀和。总结: 通过将Mamba复杂的时变递归巧妙地定义为一个满足结合律的代数操作⊕我们成功地将其转化为了一个可以使用并行扫描算法解决的经典问题。算法的输出S与我们最终想求的隐藏状态h之间存在直接的、简单的映射关系。上扫和下扫操作复杂度为LogN其他操作复杂度都是O(1)而RNN复杂度O(N)因此减少了训练时间。QAMamba主要思路首先时不变系统参数A,B,C与输入t无关复杂度为O(1)无激活缺点是现实世界问题大多是时变且非线性的串行时变系统复杂度为O(N)mamba将两者结合起来创建了一个半时变系统其中B,C是时变的而A是固定的又通过delta引入了时变性经过公式推导可以达到logN的复杂度同时引入了时不变系统中缺少的“时变”和“非线性”因素。Mamba示意图中的*是矩阵乘法还是点乘点乘只有点乘才能实现“忘记”和“记忆”操作delta才能调整忘记和记忆的多少隐状态的shape为什么是(bs,d_inner,d_state)而没有t(seq_len)维度因为这个变量存储的是当前时刻t的隐状态在selectivescan中会不断产生隐状态且可以保存delta对A的影响是什么A由于负号和exp操作恒小于0而delta由于softplus激活恒大于0两者点乘后经过exp得到A_bar_t 。隐状态更新公式h_t A_bar_t * h_{t-1} B_bar_t * x_t。因此delta越大指数部分越小A_bar_t 越接近0代表历史信息被大多遗忘当delta越小A_bar_t 越接近1代表历史信息被大多记忆delta对B的影响是什么delta与B直接点乘因此delta越大B越大而隐状态更新公式可以看到B越大输入信息记忆越多delta越小输入信息记忆越少。这刚好与对A的影响相反这相当于权衡历史信息记忆和当前输入信息记忆mamba中大量使用了广播机制为什么不直接生成全尺寸的中间变量呢A代表普适的演化规则且mamba公式中A与t无关也就是输入无关因此A应该设置的没有bs和t维度。由于deltaB,C都是与输入有关的用来进行选择所以要由输入变换过来因此带有bst维度。且避免参数爆炸既然delta与输入有关那为什么delta的第三个维度要映射为d_in而不是d_in * d_state来直接与A对齐避免参数爆炸直接对其的linear层(d_inner, d_inner * d_state)mamba设计了两个linear层进行低秩分解(d_inner, dt_rank),(dt_rank,d_inner)而d_inner一般设置为d_model两倍是较大数值因此应避免(d_inner,d_inner)的设置mamba的门控激活函数为什么为SiLUSigmoid(x) 1 / (1 exp(-x))SiLU(x) x * Sigmoid(x)是两个经典门控激活。且SiLU能放大大的正输入能抑制大的负输入不仅能开关且无上限能放大重要信息因此常用。为了得到delta用到了dt_rank是什么含义为什么要这么设计低秩分解技术防止参数太多与其他分支参数不平衡。如果不使用dt_rank,delta要用线性层(d_inner,d_inner)之前说过d_inner值很大参数太大为什么生成B,C,delta的过程只有delta使用了低秩分解因为B,C生成的线性层是(d_inner,d_state)d_state一般为固定很小的数如16因此参数很少生成B,C,delta的过程为什么要设计成先linear再split再将delta分支经过linear的操作因为delta需要低秩分解其他两个不需要进入SSM之前的Conv以及激活进入后经过linearsplit再将delta经过linear的设计动机在哪里Conv是为了在进入网络前先融合信息聚合局部上下文可以看Mamba结构图历史进入SSM前都会经过ConvConv后激活的含义是引入非线性。所以这本质上就是一个Conv-Activate-linear操作只是由于delta要低秩分解因此有了最后一个lineard_model,d_inner,d_state的值都是什么含义一般设置为多少有什么关系d_model是embedding维度也是常规方法的映射中间维度另外两个是mamba独有的维度。d_inner 是在单个 MambaBlock 内部进行计算时临时扩展到的维度d_inner d_model * expand物理含义是mambablock的特征提取能力。d_state 是 Mamba 中最独特、也最关键的参数。它只存在于 SSM 的核心部分。它代表了潜状态Latent Stateh 的维度通常是一个较小的固定数如16物理含义是模型的记忆能力。为什么要设置d_innerd_inner为什么不设置的和d_model一样实验发现d_inner设置为d_model两倍mambablock性能更好。且除了mambablock之外还有许多其他操作如backbone后处理这些依然可以使用传统方法的dmodel而mamba此时只是其中的插件

更多文章