原文发表在知乎,辛苦移步:《部分离线强化学习相关的算法总结(td3+bc/conrft)》
最近看的一些在高复杂性,时长较长的场景中使用强化学习算法提升效果的案例,例如《关于gr-rl与pi-0.6(π₀.₆)的一些想法》,论文中展示的制作咖啡,穿鞋带等实验让人印象深刻。它们在技术上有一些共同的特点,即使用强化学习,先用离线数据进行offline预训练(不是传统的模型学习,而是强化学习中的offline),然后在online与环境进行交互学习。在第一阶段的离线学习完成后,可以有一个基本的策略与critic函数,这样在线学习阶段就可以减少交互时探索的空间,整体上两阶段的学习比一阶段的学习可以较大的降低总体的学习时长。第一阶段的offline强化学习是一个很大的topic,有很多经典的算法,例如CQL/IQL/Cal-QL/TD3+BC等。td3+bc也是gr-rl论文中在离线强化学习阶段使用的算法。如下图所示(来源于gr-rl论文),在离线阶段训练出的critic函数可以充当一个progress predictor,一方面可以充当一个任务的进度显示器,对数据中的噪声(下图中深蓝色曲线的波动部分)进行一些filter,另一方面在强化学习算法中可以对策略进行监督指导。
我们知道在线强化学习有一个核心的特征就是与环境env进行实时交互,获取env的反馈,在交互中持续学习提升。而离线强化学习没有与env进行交互,只是利用采集好的数据进行离线学习。将离线训练出的模型拿到在线环境中运行的话,很容易产生OOD的问题,就是离线的数据分布与在线的数据分布变化较大,举例来说,离线的数据可能是一些较老的策略产生的数据,而策略经过持续优化后,新策略与环境交互产生的数据会跟以前有较大的不同,可能是老策略不会遇到的新场景数据。所以离线的策略在在线的环境中一般并不能运行的特别好。
笔者不那么深度的研究了CQL/IQL/Cal-QL/TD3+BC几个offline rl领域经典的算法,也研究了ConRFT《RSS 2025|ConRFT: 真实环境下基于强化学习的VLA模型微调方法》论文,此论文其实类似于gr-rl/pistar0.6类似的训练范式,使用强化学习进行两阶段的训练。其中在离线强化学习阶段使用了Cal-QL算法。本篇重点记录一下TD3+BC和ConRFT两个工作的学习心得。
TD3+BC
算法介绍
参考:《离线强化学习(Offline RL)系列3: (算法篇) TD3+BC 算法详解与实现(经验篇)》,这个算法作者开源了代码,代码极简,学习起来很方便,只有300多行。核心思想就是在传统强化学习算法基础之上再叠加bc(行为克隆)的监督信号,这样可以没有与环境进行交互的情况下,策略产生的动作也不会偏离数据集太远。整体上在简单性和效果上来看,是一个比较优秀的算法。(根据论文和网络信息判断,目前笔者也在使用此算法进行一些实验)
算法复现
开源代码主要在mujoco一些离线数据集上(大概有10个左右)进行训练,训练的产出是一个策略模型和critic模型,将策略模型在仿真环境中进行评测,通过分数来观察offline训练的策略在仿真中的效果。笔者只是简单的让它在第一个数据集上跑起来,方便研究代码,过程如下。
数据集与仿真环境
先介绍一下数据集与仿真环境,它两是对应的,数据集也是在对应的仿真环境中采集的。第一个数据集对应的是half cheetah仿真环境:Gymnasium Documentation,介绍如下。
猎豹(HalfCheetah)是一款二维机器人,由9 个身体部件和8 个连接关节组成(包含两只爪子)。其任务目标为:通过向关节施加扭矩,使该猎豹机器人以最快速度向前(右侧)奔跑。奖励机制设定为:向前移动的距离对应正奖励,向后移动则对应负奖励。需要注意的是,该机器人的躯干与头部为固定结构,扭矩仅能施加于其余6 个关节,分别为:连接躯干的前后大腿关节、连接大腿的小腿关节,以及连接小腿的足部关节。
环境安装
论文作者只给出了软件版本号,网上也有一些网友的安装步骤,因为这个工作是5年前的,所以很多安装方法都无法成功,安装环境的话,请按下面这篇文章进行:离线强化学习(Offline RL)系列2: (环境篇)D4RL数据集简介、安装及错误解决。除此之外,笔者遇到一个新的问题:
Cython.Compiler.Errors.CompileError:
/home/ubuntu/miniconda3/envs/td3bc/lib/python3.7/site-packages/mujoco_py/cymj.pyx
解决方法:找到文件路径
/home/ubuntu/miniconda3/envs/td3bc/lib/python3.7/site-packages/mujoco_py/cymj.pyx
修改以下两处:
在 c_warning_callback 和 c_error_callback 的定义中添加 noexcept
cdef extern from “mujoco.h”:
cdef void (*mju_user_warning)(const char *) noexcept nogil
cdef void (*mju_user_error)(const char *) noexcept nogil
修改 c_warning_callback 和 c_error_callback 的定义
cdef void c_warning_callback(const char *msg) noexcept nogil:
# 现有代码
cdef void c_error_callback(const char *msg) noexcept nogil:
# 现有代码
笔者安装的软件版本与论文作者并不完全一样,笔者版本如下:
Python 3.7.0
mujoco 2.3.6
mujoco-py 2.1.2.14
gym 0.23.1
D4RL 1.1 /home/ubuntu/Downloads/embodient/d4rl
torch 1.13.1
torchvision 0.14.1
指标趋势
从下图可以看到,在中间的过程算法整体上收敛后有一个较好的效果(图中reward是在用离线数据训练出来的策略在仿真环境下面进行实际运行产生的奖励),后期又发散了。
ConRFT
介绍:《RSS 2025|ConRFT: 真实环境下基于强化学习的VLA模型微调方法》,此文章讲得挺清楚了。如上面所述,此工作也是一个两阶段的训练范式,第一阶段离线强化学习,第二阶段在线强化学习。在工程实现上,它基于hil-serl的框架,但hil-serl只是一个online的训练范式。
强化学习算法
ConRFT类似于td3+bc的思路,使用了CalQL+BC算法,都在监督信号中增加了行为克隆(BC)的监督,不同的是CalQL是另一种离线强化学习算法,它基于CQL算法,做了一些改进,而td3就是较为传统且经典的强化学习算法。对于Td3/CQL/CalQL等算法在此文不展开,大家可自行检索学习。
训练时长
此工作说“可以经过45-90分钟的在线强化学习后,任务成功率可达96%以上”,这个时间相对hil-serl的2小时左右更短一些。笔者认为其受益于上一阶段的离线强化学习的效果加持。ConRFT在评测任务上在场景复杂上其实也更简单一些,例如pick banana。而hil-serl的任务设计更多,复杂度更高一些,例如安装汽车仪表盘。
模型设计
critic网络就是较为简单的resnet提取图片特征然后叠加本体状态信息,经过MLP输出q值,这点与hil-serl(lerobot版本的实现)是一致的。在actor网络设计上,ConRFT使用了Octo-small模型,而hil-serl(lerobot版本的实现)使用的跟critic网络类似的结构。
工程实现
ConRFT是基于hil-serl论文开源的代码框架进行开发,在工程架构上,使用的是actor-learner双节点的模式,actor充当与环境进行交互探索,收集数据的角色,learner拿到数据后同步进行模型训练,定期将最新的参数同步给actor。上图中双buffer的数据缓存结构跟hil-serl也是完全一样的。另外,此框架在模型训练和推理上使用了jax框架,非大家熟悉的pytorch框架,所以笔者并不太想深入研究它。
关于hil-serl这种双节点的训练框架,lerobot也实现了一版,笔者前期对hil-serl进行复现也是基于lerobot的版本,详情可参考:《具身智能hil-serl强化学习算法在lerobot机械臂上复现》。lerobot版本将节点间的通信协议更新为grpc协议,同时将jax框架更新为pytorch框架,所以更加清晰易读一些。
上面的介绍比较宏观,结合了笔者的大体框架代码的阅读与gemini大模型对代码的分析。下面笔者列举一些与gemini的问答:
critic网络的损失函数是如何定义的?
Critic 网络的损失函数定义在 serl_launcher/serl_launcher/agents/continuous/conrft_single_octo_cp.py 文件中,主要由 calql_critic_loss_fn 方法实现。
这个损失函数由两部分组成:
- 标准的时序差分损失 (Temporal Difference Loss)
- 保守的 Q 学习惩罚 (Cal-QL / CQL Penalty)
最终的 Critic Loss 是这两部分的加权和:critic_loss = td_loss + alpha * cql_loss
- 时序差分损失 (TD Loss)
这部分在 critic_loss_fn 方法中计算,是强化学习中的标准损失,目标是让 Critic 的预测遵循贝尔曼方程。
其计算步骤如下:
计算目标 Q 值 (Target Q-Value):
- 首先,使用当前的策略(Policy)为数据集中的“下一个状态” (next_observations) 生成“下一个动作” (next_actions)。
- 然后,使用一个独立的、更新较慢的 Target Critic 网络 来计算这些“下一个状态-动作对”的 Q 值 (target_next_qs)。
- 关键一步:为了抑制 Q 值过高估计,算法会从集成的多个 Target Critic(Twin-Critic)中,选取最小的那个 Q 值 (target_next_min_q)。
- 最后,根据贝尔曼方程,用即时奖励 rewards 加上折扣后的未来期望回报(即最小的 target_next_min_q)来构建最终的目标 Q 值 target_q。
1 target_q = batch[“rewards”] + discount * batch[“masks”] * target_next_min_q
计算损失:
- 使用当前的 Critic 网络预测数据集中“当前状态-动作对”的 Q 值 (predicted_qs)。
- 损失函数就是预测 Q 值 (predicted_qs) 和目标 Q 值 (target_q) 之间的 均方误差 (Mean Squared Error)。
1 critic_loss = jnp.mean((predicted_qs - target_qs) ** 2)- 保守的 Q 学习惩罚 (Cal-QL Penalty)
这部分是 Cal-QL 算法的核心,旨在解决离线强化学习中的分布外(Out-of-Distribution, OOD)动作过高估计的问题。它通过一个额外的惩罚项,压低非数据集中动作的 Q 值,使得学习到的 Q 值更加保守和可靠。
这部分的计算在 _get_cql_q_diff 方法中:
采样 OOD 动作: 针对当前状态,从多个分布中采样大量动作,包括:
- 从一个均匀或正态分布中采样的随机动作。
- 从当前策略生成的当前动作和下一个动作。
计算 OOD 动作的 Q 值: 使用当前的 Critic 网络评估所有这些采样动作的 Q 值。
应用 Cal-QL 校准: 这是 “Calibrated Q-Learning” 的体现。在计算保守惩罚之前,代码会确保对 OOD 动作的 Q 值预测不会低于数据集中实际观测到的蒙特卡洛回报 (
mc_returns)。这防止了 Critic 对实际上很好的
OOD 动作进行过度惩罚。
1 # 将OOD动作的Q值限制在不低于实际观测到的回报
2 cql_q_samples = jnp.maximum(cql_q_samples, mc_lower_bound)计算保守惩罚项:
- 通过对所有 OOD 动作的 Q 值以及数据集中真实动作的 Q 值进行 logsumexp 操作,来估计当前状态下所有可能动作的 Q 值期望。
- cql_loss 就是这个期望值与数据集中真实动作 Q 值的差。如果 OOD 动作的 Q 值普遍较高,这个差值就会很大,从而产生一个较大的惩罚。
最终,这个 cql_loss 乘以一个超参数 alpha 后,被加到 TD Loss 上,共同构成了 Critic 的总损失,驱动 Critic 在拟合贝尔曼方程的同时,保持对未知动作的保守性。
actor的损失函数是如何定义的?
Actor 网络的损失函数定义在 serl_launcher/serl_launcher/agents/continuous/conrft_single_octo_cp.py 文件中的 policy_loss_fn 方法中。
这个 Actor 是一个基于扩散模型 (diffusion model) 的策略,因此其损失函数结合了行为克隆(或扩散重建)和 Q 值最大化的思想。它包含两个主要组成部分:
- 一致性损失 (Consistency Loss),也称为扩散损失或去噪损失。
- Q 值最大化损失 (Q-Value Maximization Loss)。
最终的 Actor Loss 是这两部分的加权和:
actor_loss = self.state.bc_weight * recon_loss + self.state.q_weight * q_loss
- 一致性损失 (recon_loss)
这个损失项旨在让 Actor 学习如何从带有噪声的动作中重建出原始的(专家或目标)动作。它类似于行为克隆,鼓励 Actor 模仿演示数据中的行为。
其计算步骤如下:
原始动作 (
x_start): 从批次数据中获取真实的专家动作(batch[“actions”])。添加噪声: 随机选择一个噪声水平 t 和一个随机噪声 noise,将噪声添加到原始动作 x_start 上,得到一个带噪声的动作 x_t:
1 x_t = x_start + noise * append_dims(t, dims)去噪预测 (
distiller): Actor 网络 (self.forward_policy) 的任务是接收这个带噪声的 x_t 和对应的噪声水平 t,然后预测出原始的、无噪声的动作 (distiller)。计算重建误差: 一致性损失 (recon_loss) 是 Actor 预测的去噪动作 distiller 和原始动作 x_start 之间的 均方误差 (MSE),并根据噪声水平 t 的信噪比 (SNR) 进行加权。
1 recon_diffs = (distiller - x_start) ** 2
2 recon_loss = (mean_flat(recon_diffs) * weights).mean()
这个 recon_loss 鼓励 Actor 准确地从噪声中恢复原始动作,从而隐式地学习专家行为。Q 值最大化损失 (q_loss)
这个损失项旨在通过最大化 Critic 评估的 Q 值来改进 Actor 的策略,使其选择能带来更高回报的动作。
其计算步骤如下:
- 生成新动作 (
new_actions): Actor 根据当前的观测 (batch[“observations”]) 和任务 (batch[“tasks”]) 生成新的动作。 - Critic 评估: 将 Actor 生成的这些 new_actions 送入 Critic 网络 (self.forward_critic) 进行评估,得到这些动作的 Q 值 (q_new_actions)。
- 最大化 Q 值: 由于优化器默认是最小化损失,为了实现 Q 值最大化,Actor 的 Q 值损失被定义为评估 Q 值的 负值:
1 q_loss = - q_new_actions.mean()
通过最小化 q_loss,Actor 实际上就是在最大化其生成动作的 Q 值。
综合 Actor 损失
最终的 Actor 损失 (actor_loss) 是 recon_loss 和 q_loss 的加权组合。其中:
- self.state.bc_weight 是控制一致性损失(行为克隆)重要性的权重。
- self.state.q_weight 是控制 Q 值最大化损失重要性的权重。
这种组合方式使得 Actor 能够同时利用演示数据学习专家行为,并通过 Critic 的反馈进一步优化策略,以期在环境中获得更高的回报。
在预训练时和hil在线训练时,critic和actor的更新频率有什么不同吗?
是的,预训练和在线训练阶段在 Critic 和 Actor 的更新机制上存在显著的不同。
总结来说,更新的频率结构是相同的,但所使用的损失函数和数据来源是不同的。
相同点:更新频率结构
在两个阶段中,代码都采用了相同的 Critic-to-Actor Ratio (cta_ratio=2) 结构。这意味着,无论是预训练还是在线训练,都会遵循“N-1 次只更新 Critic,然后进行 1 次同时更新 Actor 和 Critic”的模式。
这种让 Critic 比 Actor 更新更频繁的策略在两个阶段都得到了保持。
不同点:
- 损失函数的不同
这是最关键的区别,直接影响了学习的目标。
预训练 (Offline Pre-training):
- 使用 agent.update_calql 方法进行更新。
- 这会调用 calql_critic_loss_fn,该函数计算的是 Cal-QL (或 CQL) 损失。
- 目的: Cal-QL 是一种保守的 Q 学习算法,专门为离线数据集设计。它通过一个惩罚项来压低未在数据集中见过的“分布外”动作的 Q 值,防止过高估计,从而在纯离线数据上学到更可靠、更安全的策略。
在线训练 (Online HIL Training):
- 使用 agent.update_ql 方法进行更新。
- 这会调用 critic_loss_fn,该函数只计算标准的时序差分损失 (TD Loss)。
- 目的: 在线阶段由于有持续的 HIL 互动数据流入,策略可以进行探索,算法不再需要像纯离线时那样极端保守。因此,它切换回了标准的 Q 学习更新方式,去掉了 Cal-QL 的保守惩罚项。
- 数据来源的不同
预训练:
- 训练数据完全来自 demo_buffer,也就是预先收集好的专家演示数据集。这是一个纯粹的离线 (offline) 训练过程。
在线训练:
- 训练数据是混合的。每一批 (batch) 数据都由两部分拼接而成:
- 一半来自 replay_buffer(存储了与环境/HIL 实时交互产生的在线 (online) 数据)。
- 另一半来自 demo_buffer(专家演示数据)。
- 这种方式使得智能体在利用过往专家经验的同时,也能从新的在线探索和互动中学习。
- 训练数据是混合的。每一批 (batch) 数据都由两部分拼接而成:
综上所述,虽然更新频率的宏观结构(cta_ratio)保持不变,但两个阶段在学习目标(保守 vs. 标准 Q 学习)和数据来源(纯离线 vs. 混合在线)上存在根本性的差异。