Swin Transformer中的相对位置偏置:从理论到代码实现(附PyTorch示例)

张开发
2026/4/16 1:53:16 15 分钟阅读

分享文章

Swin Transformer中的相对位置偏置:从理论到代码实现(附PyTorch示例)
Swin Transformer中的相对位置偏置从理论到代码实现附PyTorch示例在视觉Transformer领域Swin Transformer以其独特的层级式窗口注意力机制脱颖而出。其中相对位置偏置Relative Position Bias作为核心创新之一巧妙地解决了传统Transformer在视觉任务中忽视空间关系的痛点。本文将深入剖析这一技术的实现细节并通过可运行的PyTorch代码演示其完整工作流程。1. 相对位置偏置的数学本质传统自注意力机制计算query和key的点积时本质上是在度量两个token的语义相关性。但在视觉任务中像素或图像块之间的空间位置关系同样蕴含重要信息。相对位置偏置通过引入可学习的偏置矩阵B将几何先验注入注意力权重# 标准注意力计算公式含相对位置偏置 Attention(Q, K, V) Softmax(QK^T/√d B)V其中B ∈ ℝ^(n×n)的每个元素B_ij表示query位置i与key位置j的相对位置编码。与绝对位置编码不同这种设计具有以下优势平移等变性窗口移动时相同相对位置的偏置值保持一致计算效率共享偏置参数大幅减少参数量灵活性可学习机制能自适应不同任务的空间关系模式下表对比了三种主流位置编码方式的特性编码类型参数量平移等变长程依赖处理典型应用绝对位置编码O(n)❌✅ViT相对位置偏置O(n²)✅❌窗口内Swin Transformer旋转位置编码O(1)✅✅RoFormer2. Swin Transformer的实现解析2.1 偏置表的初始化Swin Transformer采用可学习的参数表存储不同相对位置的偏置值。对于窗口大小为M×M的配置实际需要覆盖的相对位置范围为[-M1, M-1]class SwinBlock(nn.Module): def __init__(self, window_size, num_heads): super().__init__() # 初始化偏置表(2M-1)×(2M-1) × num_heads self.relative_position_bias_table nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 构建相对位置索引后续详细解释 self._build_relative_position_index(window_size)以7×7窗口为例偏置表包含(2×7-1)²169个位置组合每个头有独立的偏置参数。这种设计使得模型可以学习不同注意力头关注不同的空间关系模式。2.2 位置索引的构建核心挑战在于如何将二维相对坐标映射到一维的偏置表索引Swin Transformer采用了一种巧妙的笛卡尔积编码方案def _build_relative_position_index(self, window_size): # 生成网格坐标 coords_h torch.arange(window_size[0]) coords_w torch.arange(window_size[1]) coords torch.stack(torch.meshgrid(coords_h, coords_w)) # 2, Wh, Ww # 计算相对坐标 coords_flatten torch.flatten(coords, 1) # 2, Wh*Ww relative_coords coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww # 坐标归一化到[0, 2M-2] relative_coords relative_coords.permute(1, 2, 0).contiguous() relative_coords[:, :, 0] window_size[0] - 1 relative_coords[:, :, 1] window_size[1] - 1 # 一维化编码 relative_coords[:, :, 0] * 2 * window_size[1] - 1 relative_position_index relative_coords.sum(-1) self.register_buffer(relative_position_index, relative_position_index)对于2×2窗口生成的索引矩阵如下所示tensor([[4, 3, 1, 0], [5, 4, 2, 1], [7, 6, 4, 3], [8, 7, 5, 4]])该矩阵的对称性反映了相对位置的双向特性——位置i到j的偏置与j到i的偏置虽然数值不同但存在系统性的对应关系。3. 前向传播中的动态偏置注入在实际计算注意力时相对位置偏置作为附加项参与计算def forward(self, x): B, N, C x.shape qkv self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) q, k, v qkv.unbind(2) # B, N, num_heads, head_dim # 标准注意力计算 attn (q k.transpose(-2, -1)) * self.scale # 注入相对位置偏置 relative_position_bias self.relative_position_bias_table[ self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww, Wh*Ww, nH attn attn relative_position_bias.permute(2, 0, 1).unsqueeze(0) attn attn.softmax(dim-1) return (attn v).transpose(1, 2).reshape(B, N, C)关键步骤解析通过索引矩阵从偏置表中查取对应位置的偏置值将偏置矩阵调整为与注意力分数相同的形状(nH, N, N)直接相加后参与softmax计算这种实现方式在保持高效的同时完美融入了空间位置信息。4. 实战自定义相对位置偏置模块以下是一个可复用的相对位置偏置模块实现支持任意窗口尺寸class RelativePositionBias(nn.Module): def __init__(self, window_size, num_heads): super().__init__() self.window_size window_size self.num_heads num_heads # 初始化偏置表 self.relative_position_bias_table nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 构建位置索引 coords_h torch.arange(window_size[0]) coords_w torch.arange(window_size[1]) coords torch.stack(torch.meshgrid([coords_h, coords_w])) coords_flatten torch.flatten(coords, 1) relative_coords coords_flatten[:, :, None] - coords_flatten[:, None, :] relative_coords relative_coords.permute(1, 2, 0).contiguous() relative_coords[:, :, 0] window_size[0] - 1 relative_coords[:, :, 1] window_size[1] - 1 relative_coords[:, :, 0] * 2 * window_size[1] - 1 relative_position_index relative_coords.sum(-1) self.register_buffer(relative_position_index, relative_position_index) # 初始化参数 trunc_normal_(self.relative_position_bias_table, std.02) def forward(self): relative_position_bias self.relative_position_bias_table[ self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) return relative_position_bias.permute(2, 0, 1).contiguous()使用示例# 初始化模块 window_size (7, 7) num_heads 8 rel_pos_bias RelativePositionBias(window_size, num_heads) # 在注意力计算中使用 attn_scores (q k.transpose(-2, -1)) / math.sqrt(dim) attn_scores attn_scores rel_pos_bias()实际项目中这个模块可以无缝集成到现有的Transformer架构中。我在多个视觉任务实验中发现合理设置窗口大小对模型性能影响显著——过小的窗口会限制感受野而过大的窗口则会导致偏置表过于稀疏。

更多文章