Skip to main content

Text2SQL × GRPO 实战:踩了三个大坑之后学到的教训

原文:用 GRPO 训练 Text2SQL:三个差点没爬出来的坑

这篇文章是基于知乎原文的深度扩写。我以第一视角补充了大量技术细节、直观解释和工程教训——用一个医疗 Text2SQL 项目,把 GRPO RL 训练的三大致命坑摊开来拆解清楚。


这篇文章说了什么

这是我第一个 Text2SQL 项目——用 GRPO 训练一个开源模型,把医生的自然语言问题自动转成 SQL 查询。场景是医疗数据分析:客户有几十个业务数据库,表名是拼音缩写、字段命名混乱、还有大量业务特有的 JOIN 逻辑,分析师每天手写 SQL 写到崩溃。

因为数据库内容涉及患者隐私,合规要求所有 prompt 不能出内网,只能用本地部署的开源模型。正好我们这个场景有大量历史 SQL 日志和数据库,但缺少「自然语言→SQL」的配对标注——找分析师手动标一条要 5-10 分钟,根本标不起。

GRPO 恰好适合这种情况:不需要人工标注,让模型自己生成 SQL,到数据库上跑一遍,结果对了就奖励、错了就惩罚。但这个过程不是一帆风顺的——我踩了三个差点没爬出来的坑。


两个回合下来,有效的实验就这几招

方法做了什么效果
Ground Truth 执行奖励SQL 在真实数据库上执行,结果跟标准答案比对,F1 软评分(部分正确的也有分)这是整个训练的生命线——没有它 RL 就是低效版 SFT
代理奖励辅助收敛格式规范 + 语法检查 + Schema 匹配 + N-gram 相似度,权重各 0.1~0.5提供连续梯度,引导方向,execution reward=0 时模型也知道往哪走
DAPO + Cosine LR + KL 约束DAPO 算法 + cosine 学习率衰减 + KL 散度约束 β=0.01既不过拟合(Dr.GRPO KL 涨 38 倍),也不学不动(REINFORCE++ 太保守)
SQL 执行超时防护ThreadPoolExecutor + .result(timeout=30) 包裹每次 SQL 执行防一条笛卡尔积 SQL 拖死整个 8 卡 DDP 训练
两层 RAG Agent第一层业务路由(关键词+embedding),第二层 Schema 检索 + few-shot复杂查询从 40% 拉到 87%(训练后),zero-shot 87% 超训练前 few-shot 78%

试了但翻车的方法:Dr.GRPO(KL 爆炸 38 倍)、REINFORCE++(太保守学不动)、裸 GRPO(前期没 gradient,代理奖励撑不住方向)。


一、先搞懂这些概念

1.1 Text2SQL 是什么

Text2SQL 就是把自然语言转成 SQL 查询。举个例子:

用户输入:
"2024 年 3 月门诊部接诊了多少 65 岁以上的糖尿病患者?"

期望输出:
SELECT COUNT(DISTINCT patient_id)
FROM outpatient_visits
WHERE visit_date BETWEEN '2024-03-01' AND '2024-03-31'
AND age > 65
AND diagnosis LIKE '%糖尿病%'

看起来简单,但实际场景里这个查询涉及的业务规则远超写 SQL 本身——哪张表留存诊信息(outpatient_visits 还是 mz_jz 这种拼音缩写)?年龄字段叫 age 还是 nl(年龄)还是藏在患者主表里?诊断字段存的是 ICD 编码还是中文文本?

Text2SQL 的本质难点不是"写 SQL 语法",而是把用户模糊的自然语言问题,映射到一个你从没见过的、命名随意的数据库 schema 上

1.2 Ground Truth 奖励 vs 代理奖励

在 GRPO 训练里,奖励函数是灵魂。这个项目里我设计了两类奖励信号,分别扮演不同角色:

Ground Truth 奖励(不可替代):执行正确性。 生成的 SQL 在真实数据库上跑一遍,查询结果跟标准答案比对。这是唯一能回答"SQL 写对了没有"的信号。我给的是 F1 软评分——一条返回了 100 行中 99 行正确的 SQL,能拿到 0.99 分,而不是跟完全写错的一样拿 0 分。

代理奖励(辅助平滑梯度):四个维度。

为了让你看懂每个奖励在干什么,我用同一条 prompt 来演示。假设标准答案 SQL 是:

-- prompt:「上个月门诊部开了多少盒二甲双胍?」
-- 标准答案(Ground Truth):
SELECT SUM(ypsl)
FROM mz_cf
WHERE cfrq BETWEEN '2024-05-01' AND '2024-05-31'
AND ypid = 'METFORMIN_500'

模型生成了一条错得千奇百怪的 SQL:

-- 模型输出:
wrong output: SELECT COUNT(*) FROM yf_kc WHERE rksj > '2024-05-01' AND ypmc LIKE '二甲双胍'

现在用四个代理奖励给这条 SQL 打分:

代理奖励怎么算这条 SQL 得分为什么
格式规范输出是否包含 <reasoning> + <answer> XML 标签0「wrong output:」不能改成 XML,格式不对
语法正确SQL 能不能被 parser 解析(CREATE TABLE 后直接 SELECT 不报错)1SELECT COUNT(*) FROM yf_kc WHERE ... 语法没问题
Schema 匹配模型用的表名/列名跟标准答案的 Jaccard 相似度0.2标准答案的 token 集 = {mz_cf, ypsl, cfrq, ypid, METFORMIN_500};模型用了 {yf_kc, rksj, ypmc},交集为空,但 WHEREAND 等 SQL 关键词共同命中,拿到一点分
N-gram 相似度跟标准 SQL 的 bigram(两词序列)重合度0.15标准 SQL 的 bigram = {SELECT SUM, SUM(, ypsl FROM, FROM mz_cf, ...};模型 bigram = {SELECT COUNT, COUNT(, FROM yf_kc, ...},SELECTFROM 相关 bigram 有部分重合

最终这条 SQL 的总 reward = 3.0×0(执行正确性)+ 0.1×0 + 0.2×1 + 0.3×0.2 + 0.2×0.15 = 0.29。Ground Truth 权重 3.0 占大头,但因为表全选错了,执行结果跟标准答案对不上=0。代理奖励提供了 0.29 的连续信号,告诉模型「语法对了、但表和字段全选错了,方向需要大调」。

这就是代理奖励的定位:在 Ground Truth 为零的时候,提供连续梯度告诉你错在哪、有多接近。

为什么需要代理奖励?打个比方你就明白了:

想象你在教一个人打靶,但他的靶子蒙着一层布——你只看得到最终结果(靶纸送过来才知道中没中)。 「执行正确性」就是这张靶纸——它是唯一能回答「打没打中」的信号。但问题是:靶纸要等整轮打完之后才能拿到,中间每一枪你不知道偏了多少。

「代理奖励」就是你在旁边看他端枪的姿势。扳机扣动的时机不对(语法错误)——扣分;枪口朝向明显歪了(表名都没选对)——扣分;但枪口方向大致对了、只是手腕微调的问题(表名对了、JOIN 逻辑接近)——加分,告诉他「方向对了,改细节就行」。

没有代理奖励的训练:模型生成了 100 条 SQL,95 条执行后结果不对(reward=0),5 条对了。模型只知道 5 条是好的,剩下 95 条错在哪了?完全不知道——就像靶纸上只有一个「中」或「没中」,没有任何关于偏了多少的信息。

有代理奖励的训练:那 95 条错的 SQL 也能拿到部分分数——「表名选对了(+0.3)但 JOIN 错了」「语法没问题(+0.2)但 WHERE 条件错了」。模型就能知道:「我离正确的方向还差多远,该调哪里。」

这就是两类奖励的关系:Ground Truth 告诉你什么是正确答案,代理奖励告诉你错了的时候错在哪里、有多接近。 只有前者,模型在黑暗中摸索;只有后者,模型在学表面模式而不是真正解题。

权重分配:执行正确性权重 3.0,四个代理奖励各 0.1~0.5。 保证模型优先学"写对 SQL",而不是"写得像标准答案"。

1.2.1 Advantage 和梯度——分数怎么变成学习信号

搞清楚奖励怎么打分之后,下一步是:这些分数怎么变成模型的学习信号?这里有两个核心概念。

Advantage(优势值):GRPO 不关心 0.9 够不够好,只关心 0.9 比组内其他人好多少。每次对同一个 prompt 生成 8 个答案,算出各自的得分后,用公式算每个答案的 advantage:

advantage = (自己的分数 - 8 个答案的平均分) / 8 个答案的标准差

advantage 越大,说明这条答案在组内越突出,模型越应该往这个方向学。advantage 为负,说明这条答案比组内平均还差,模型应该远离它。

梯度(Gradient):模型参数该往哪个方向改、改多大。advantage 告诉你方向(向好答案靠还是远离坏答案),梯度算出力度(这一步迈多大)。

拿一条具体 prompt 的例子走一遍:

Prompt:「上个月门诊部开了多少盒二甲双胍?」
8 个答案的得分:[0.95, 0.30, 0.85, 0.42, 0.78, 0.25, 0.91, 0.38]

均值 = 0.605,标准差 = 0.278
答案得分advantage模型学到什么
第 1 个0.95(0.95-0.605)/0.278 = +1.24梯度往这个方向推,strong positive
第 7 个0.91(0.91-0.605)/0.278 = +1.10也往这个方向推
第 5 个0.78(0.78-0.605)/0.278 = +0.63弱 positive
第 6 个0.25(0.25-0.605)/0.278 = -1.28梯度远离这个方向,strong negative

这就是 GRPO 的工作原理:不依赖绝对分数有多精确,只依赖组内的相对差异。


但这里有一个致命暗坑:组内全部趋同。

训练中后期,模型已经很稳了,8 个答案的得分全是 0.8:

[0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]
均值 = 0.8,标准差 = 0
advantage = (0.8 - 0.8) / 0 = NaN(除零)
梯度 = 0

标准差为零 → 全部 advantage 等于 NaN → 这一步训练白费。不管 8 个答案质量如何,只要它们得分一样,模型就不会更新。

我的实验里 frac_reward_zero_std 一度飙到 67%。意味着每三步训练就有两步在浪费 8 张 H100 的算力——GRPO 跑了一晚上,其中超过一半的时间在做无用功。

怎么办? 核心思路是强制制造组内方差。解法有两个:

解法一:DAPO 的动态采样。 当检测到 reward 标准差持续过小时,自动增大 num_generations(如 8→16)。样本多了,组内更容易出现差异。

解法二:注入虚拟满分。 这是最便宜也最直接的一个技巧,用数字走一遍你就明白了。

刚才那个死局的场景:8 个答案全是 0.8,标准差为零,梯度 NaN,死锁。

死局:
[0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]
均值 = 0.8,标准差 = 0
→ advantage 全部 NaN → 这一步白算

虚拟满分的做法:在每组末尾追加一条 reward=1.0 的"假答案"。 这条假答案不参与 loss 计算(它不是真的模型输出),只参与 advantage 计算(让标准差不再是零)。

破局:
[0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 1.0]
↑ 追加的虚拟满分
均值 = 0.822,标准差 = 0.067
→ 每个 0.8 的 advantage = (0.8 - 0.822) / 0.067 = -0.33

加了这条 1.0 之后,标准差从 0 变成了 0.067。原来 8 个 0.8 的答案现在全部算出一个 slight negative advantage(-0.33),模型会往远离这些"平平庸庸的 0.8"的方向微调。

为什么这个 trick 有效? 因为 advantage 的相对性——模型不关心 0.8 够不够好,只关心 0.8 比组内其他人差多少。虚拟满分告诉模型:「你这些 0.8 的答案还不是最好的,还有提升空间。」模型就会继续努力往满分方向靠,而不是因为组内全同就原地不动。

代价是什么? 虚拟满分是一个固定锚点,加了之后所有真实答案的 advantage 都会偏低(因为被 1.0 比下去了)。这会让梯度偏保守——学得慢一点,但不会死锁。在「学得慢」和「完全不学」之间,选慢不选停。

但这里有一个更重要的问题:虚拟满分不会把模型带偏吗?

会,如果你理解错了它的作用方式。虚拟满分只告诉模型「你这些 0.8 的答案还不够好」,但它不告诉模型哪个方向是好方向。模型远离 0.8 不代表它在往 1.0 走——它只是往一个随机的方向挪了一步。

那为什么还用?因为它破的不是「方向问题」,而是「死锁问题」。关键在这里:这一步挪完之后,下一轮训练 Ground Truth 执行奖励又会重新打分、重新算 advantage。 只要 Ground Truth 信号在正常工作,模型挪出去之后,下一轮自然会收到真实的反馈——「这个方向偏了,改回来」或「方向对了,继续」。虚拟满分只是帮模型从「动不了」变成「能动」,后续的方向纠偏靠的是 Ground Truth。

什么时候不能用? 这就是第一个坑的教训——如果你的 Ground Truth 执行奖励没有正常工作(比如数据库路径配错了),虚拟满分会让模型在错误的方向上随机游走,越训越差。它只能在「Ground Truth 正常、只是组内方差消失」的场景下用。本质上,它是给已经会解题的模型注射一针兴奋剂让它继续练,不是给还不会解题的模型教它怎么解题。

1.3 DAPO 是什么

DAPO(Dynamic Sampling Policy Optimization)是 GRPO 的改进变体。要理解它解决了什么问题,先看一个具体的数字例子。

GRPO 每次训练,对同一个 prompt 生成 N 个答案,然后比较它们的优劣来算梯度。问题是:如果 N 太小,生成的答案可能全对或全错——组内没差异,这一步训练白费。

举个实际例子。假设模型在某条 prompt 上的正确率是 60%:

num_generations全对概率全错概率「白费」概率
40.6⁴ = 13.0%0.4⁴ = 2.6%15.6%
80.6⁸ = 1.7%0.4⁸ = 0.07%1.7%

换成白话:num_generations=4 时,每训练 6-7 步就有 1 步白算——8 张 H100 跑一个小时,其中 10 分钟在算空气。num_generations=8 时,这个概率降到 1.7%,基本不浪费。

DAPO 除了增大采样数,还做了两个关键改进:

  • 动态采样策略:根据当前 reward 分布自动调整采样数——模型还不稳时多采样(避免运气成分),稳了之后少采样(省算力)。
  • 更强的 KL 散度约束:训练过程中不断检查「模型偏离原始策略多远」,偏离超过阈值就拉回来。这是为了防止模型为了刷高 reward 而写出格式对、语义错的 SQL(第三个坑里会详细讲)。

我最终选 DAPO + cosine LR + KL 约束 β=0.01 作为基础算法。

1.4 RAG Agent 是什么

RAG(Retrieval-Augmented Generation)= 先检索相关信息,再让模型基于这些信息生成答案。

这个项目里模型需要知道"这个问题该查哪些表和字段"。但几十个数据库、上百张表的 schema 拼起来远超上下文窗口。RAG Agent 的做法是:先把用户问题做一次"业务路由"(判断属于门诊库还是药房库),然后在目标数据库内检索最相关的表和字段,只把这些相关 schema 塞进 prompt。

用一条具体查询走一遍,你就知道它为什么有必要了。 假设分析师问:

「上个月门诊部开了多少盒头孢?」

这条问题背后涉及的数据库结构是这样的:

候选数据库候选表相关字段
门诊库 (30 张表)mz_cf(处方表)、mz_jz(就诊表)、mz_yp(药品字典)等cfrq(处方日期)、ypmc(药品名称)、ypsl(药品数量)
药房库 (25 张表)yf_kc(库存表)、yf_rk(入库表)、yf_ck(出库表)等rksj(入库时间)、cksj(出库时间)
住院库 (40 张表)zy_yz(医嘱表)、zy_br(患者表)等

如果眉毛胡子一把抓,把 95 张表的 schema 全塞进 prompt:

  • 首先,token 数直接爆上下文窗口。
  • 其次,模型会被一堆无关信息干扰——"药房库存"、"住院医嘱"跟「开了多少盒头孢」毫无关系。

RAG Agent 做的事:先识别「门诊」关键词 + embedding 相似度 → 锁定门诊库 → 在门诊库 30 张表里检索跟「头孢」「开药」最相关的 3-5 张表 → 最终塞进 prompt 的只有 3-5 张表的 schema + 3 条历史类似 SQL 示例 + 原始问题。上下文从「整个医院」缩小到「门诊开药」。


二、任务定义

  • 场景:医疗数据分析 Text2SQL。客户有几十个数据库,表名拼音缩写、字段命名不规范、JOIN 逻辑是业务特有的
  • 约束:包含患者隐私数据的 schema 不能出内网 → 必须本地部署开源模型
  • 数据:有历史 SQL 日志和标准执行结果,但没有 NL→SQL 配对标注
  • 模型:Qwen2.5-Coder-32B-Instruct,4bit 量化 + LoRA,8 张 H100

三、三个差点没爬出来的坑

第一个坑:数据库路径配错了——RL 在没有真实反馈的情况下跑了整晚

跑了几种 GRPO 变体同时对比(Dr.GRPO、DAPO、REINFORCE++),扔到 8 卡上跑了一晚上。早上看 metrics 曲线:schema 匹配度从 0.50 涨到 0.65,ngram 相似度也在涨。 一切看起来在变好。

然后我发现一个事实:执行正确性始终是 0。

查了一个小时才发现——训练脚本里数据库路径配错了,SQL 根本没有真正执行过。模型收到的所有反馈都来自代理奖励。schema 匹配度涨了,只是因为模型学会了一种模式匹配——"看到这种问题,就选这几张表"。

这不叫学会写 SQL,这叫学会了猜表名。

就像学游泳的人在陆地上对着视频练动作——姿势看起来像了,但一下水就沉。模型学到的不是 SQL 的语义正确性,而是「看到这种问题描述,就挑这几张表」的概率映射。

具体地说,训练了一个晚上之后,模型面对「上个月门诊部开了多少盒二甲双胍」这条 prompt,输出是这样的:

-- 训练后的模型输出(执行正确性始终为 0 的情况下)
SELECT SUM(ypsl)
FROM mz_cf
WHERE cfrq BETWEEN '2024-05-01' AND '2024-05-31'
AND ypid = 'METFORMIN_500'

看着是不是很对?表名 mz_cf 对了,日期过滤对了,药品 ID 也有了。但这一条 SQL 从来没能真正在数据库上跑过——表名对了只是因为模型学会了「看到"门诊"就输出 mz_cf」的模式匹配。一旦问题稍微变一下,比如「上个月门诊部哪些患者同时开了二甲双胍和胰岛素」,模型就露馅了:

-- 稍微变一下 prompt,模型就写出完全不靠谱的 JOIN
SELECT p.name
FROM mz_cf c
JOIN patient p ON c.ypid = p.name
WHERE c.cfrq BETWEEN '2024-05-01' AND '2024-05-31'
AND c.ypid = 'METFORMIN_500' AND c.ypid = 'INSULIN_GLARGINE'

JOIN ON c.ypid = p.name 把药品 ID 跟患者姓名 JOIN——语法对、表都用对了、但逻辑完全不通。更致命的是 ypid = 'METFORMIN_500' AND ypid = 'INSULIN_GLARGINE'——一个字段同时等于两个不同值,这条 SQL 永远不会返回行。这就是只有代理奖励没有 Ground Truth 的下场:模型学会了用正确的积木,但不知道怎么搭。

这个坑直接验证了奖励函数设计的一个关键原则:代理奖励能引导方向,但不能替代 Ground Truth。 没有执行奖励的 RL 训练,本质上退化成了低效版 SFT——在模仿标准答案的表面特征,不是在学语义正确的 SQL。

教训:开 RL 训练前,先跑一轮 smoke test——随机挑 10 条数据,确认 Ground Truth 奖励信号在正常工作。 一个路径错误浪费了一晚上 8×H100 的算力。


第二个坑:一条笛卡尔积 SQL 搞崩了整个 8 卡训练

修好数据库路径后重新训练。execution reward 终于从 0 开始涨了。

跑到 22%,整个 8 卡训练突然挂了。 没有任何预警,所有进程一起 SIGABRT。

查日志看到这行:

WorkNCCL(SeqNum=7034, OpType=_ALLGATHER_BASE)
ran for 1800063ms before timing out

翻译:某个 GPU 在一个同步操作上卡了 30 分钟。其他 7 张卡算完了在等它,集体超时退出。

根因:模型生成了一条笛卡尔积 SQL:

-- 这条 SQL 就是崩掉 8 张 H100 的罪魁祸首
SELECT *
FROM mz_cf c, mz_jz j, mz_yp y
WHERE c.cfrq BETWEEN '2024-05-01' AND '2024-05-31'

三张大表交叉连接(CROSS JOIN)——处方表 mz_cf 10 万行 × 就诊表 mz_jz 50 万行 × 药品字典 mz_yp 8000 行 = 4000 亿行的中间结果集。SQLite 老老实实地在算这张卡上的笛卡尔积,直到其他 7 张卡的 NCCL watchdog 判定超时。

DDP 训练的致命弱点:一个 rank 卡住 = 全体卡住。 这不是 GRPO 的 bug,这是分布式训练的基础约束——ALL_GATHER 需要所有 rank 的数据到齐才能继续。就像 8 个人一起搬桌子,7 个人搬起来了,1 个人不动——桌子就是抬不上去。一个 rank 在 IO 里卡 30 分钟,另外 7 个就只能等 30 分钟。

作为一个 Text2SQL 新手,修了三次才搞定:

第一次:signal.alarm(30) 设超时。 失败了。SIGALRM 只能发给主线程,DDP 的 worker 线程收不到信号,SQL 执行线程完全不受影响。

第二次:threading.Timer 又失败了。SQLite 连接不是线程安全的,跨线程关闭连接会导致 segfault——没有优雅降级,直接进程崩溃。

第三次:ThreadPoolExecutor 终于稳了。

from concurrent.futures import ThreadPoolExecutor, TimeoutError

executor = ThreadPoolExecutor(max_workers=1)

def execute_with_timeout(db_path, sql, timeout=30):
try:
future = executor.submit(run_sql, db_path, sql)
result = future.result(timeout=timeout)
return result
except TimeoutError:
return None # reward = 0,这条 SQL 不参与训练

把每次 SQL 执行扔到独立线程,用 .result(timeout=30) 等结果。超时就认栽,返回 None,reward=0,不会拖死其他 rank。

教训:RL 训练里跑外部程序(SQL 执行、代码运行、API 调用),一定要加超时。尤其是 DDP 多卡环境,一个 rank 的阻塞会传染给集群里所有 rank。


第三个坑:算法没选对,KL 散度爆炸了 38 倍

GRPO 有几个有名的变体,我一开始无脑全试了,想对比哪个最好。

Dr.GRPO:学得飞快,但代价是 KL 爆炸。

在讲这个坑之前,先说一下 KL 散度是什么意思。KL 散度衡量的是「现在的模型跟原始模型变得有多不一样」。 你可以把它想象成一条橡皮筋——原始模型在橡皮筋的一端,训练过程中模型越学越远,橡皮筋就拉得越紧。0.00028 是训练刚开始时的松弛状态,0.0108 是橡皮筋被拉到接近崩断的状态。一旦崩断,模型就开始胡言乱语——语法看起来对,但内容已经脱离逻辑。

Dr.GRPO 去掉了组内标准差归一化——相当于给梯度加了杠杆。这就像开车:标准 GRPO 踩油门是按固定力度,Dr.GRPO 是每次把油门踩的力度乘以当前速度——前期开得确实快,reward 蹭蹭涨。但到了中后期,KL 散度从 0.00028 飙到了 0.0108,涨了 38 倍

这时候模型开始出现诡异行为——生成一些语法正确但语义荒谬的 SQL。比如把患者年龄跟药品库存做 JOIN:

-- 这条语法完全正确,但在逻辑上是扯淡
SELECT AVG(p.age), SUM(s.stock_qty)
FROM patient p
JOIN drug_stock s ON p.age = s.stock_qty

为什么会写出这种东西?因为模型为了刷高 reward,学会了钻规则的空子——语法检查满分、格式规范满分、甚至 N-gram 跟标准答案相似度也在涨。但语义已经完全脱离了数据库的真实逻辑关系。

就像学生摸清了选择题的出题规律:虽然不知道知识点,但能靠「三长一短选最短」拿到高分。 —— KL 爆炸的本质就是模型在学应试技巧,而不是学解题能力。

REINFORCE++:稳如老狗,但也学不到东西。

跟 Dr.GRPO 相反,REINFORCE++ 极其保守。KL 散度几乎不变,reward 也纹丝不动。就像开着手刹踩油门——太稳了,稳到完全没进步。

最后选了 DAPO + 三个关键调整:

调整做了什么为什么
Cosine LR schedule前期正常学,后期自动降低学习率防止后期策略漂移失控——学得差不多了就降速
KL 约束 β=0.01策略偏离超阈值时,在 loss 里加惩罚项相当于给橡皮筋加了安全锁——拉到一定程度就自动拉回来
num_generations 4→8每个 prompt 生成 8 个答案做组内比较组内全对或全错的概率从 15.5% 降到 1.7%

这个组合兼顾了学习速度和稳定性:Dr.GRPO 是有进度但失控,REINFORCE++ 是稳但没进度,DAPO 在中间——进度看得见,橡皮筋拉不断。

教训:算法选择要平衡激进与保守。 学得快的容易崩(Dr.GRPO),太稳的学不动(REINFORCE++)。DAPO + cosine LR + KL 约束是一个不错的起点——它是中间地带,有进度但不失控。


四、推理侧:两层 RAG Agent

模型训好了,但实际推理时还有一个问题:客户有几十个数据库、上百张表,模型需要先知道"这个问题该查哪些表"才能写 SQL。把所有表的 schema 全部塞进 prompt 不可能——几十个库的 schema 拼起来远超上下文窗口。

我加了两层 RAG Agent。用一条具体查询走一遍完整流程:

用户问:「上个月门诊部头孢类药品的日均消耗量是多少?」

第一层:业务路由。 把用户问题跟业务领域做匹配。

匹配方式做了什么结果
关键词匹配"门诊部" → 门诊库门诊库得分最高
embedding 相似度整句跟各库描述向量算余弦相似度门诊库 0.91 > 药房库 0.34 > 住院库 0.12

这一步把 3 个候选库缩小到 1 个:门诊库。

第二层:Schema 检索。 在门诊库内用问题检索最相关的表和字段。

检索步骤具体操作
表检索用「头孢」「消耗量」「日」在门诊库 30 张表中召回 top-5 张:mz_cf(处方表)、mz_yp(药品字典)、mz_jz(就诊表)、mz_gh(挂号表)、mz_sf(收费表)
字段检索从上一步 5 张表的字段中筛选出最相关的 15-20 个字段
few-shot 检索从历史 SQL 日志中找出 3 条语义相似的查询作为示例

最终塞进 prompt 的内容只有三样:

1. 相关 schema 定义(5 张表的关键字段)
- mz_cf: cfrq(处方日期), ypid(药品ID), ypsl(数量)
- mz_yp: ypid(药品ID), ypmc(药品名称), yplb(药品类别)
- mz_jz: jzrq(就诊日期), patient_id, ksdm(科室代码)
...

2. 3 条 few-shot SQL 示例
- "上月门诊抗生素处方总量" → SELECT SUM(c.ypsl) FROM mz_cf c JOIN mz_yp y...
- "本周门诊接诊人次" → SELECT COUNT(*) FROM mz_jz WHERE jzrq...

3. 用户的自然语言问题
「上个月门诊部头孢类药品的日均消耗量是多少?」

效果上,RAG Agent 对复杂查询的提升最明显——简单查询模型 zero-shot 就能搞定,但涉及多表 JOIN、子查询的复杂场景,few-shot 示例引导着模型走通关键 JOIN 路径,效果天差地别。


五、最终结果

训练前(Base Qwen2.5-Coder-32B + RAG):

难度few-shot 准确率
简单78%
复杂40%

训练后(GRPO + RAG):

难度zero-shotfew-shot
简单87%95%
复杂62%87%

几个值得注意的对比:

  • zero-shot 87% 超过了训练前的 few-shot 78%。 GRPO 确实让模型学到了东西,而不只是靠 few-shot 示例撑着——模型内部的 SQL 生成能力真的变强了
  • 复杂查询提升最大:40%→87%(+47 点)。 多表 JOIN 和嵌套子查询是 RL 训练收益最大的地方——规则奖励函数对这些场景区分度太差,Ground Truth 执行奖励反而精准
  • few-shot 在训练后仍有 ~8% 的额外提升。 RAG Agent 和模型能力是互补的——模型本身写 SQL 强了,给几个参照示例还能再往上提一截

看一条复杂查询的训练前后对比,就能直观感受 +47 点提升长什么样。

用户问:「上个月门诊部开了降糖药的患者中,有多少人同时挂了内分泌科的号?」

这条查询涉及三表 JOIN(处方表 mz_cf + 药品字典 mz_yp + 挂号表 mz_gh),子查询过滤降糖药分类,再加 GROUP BY 去重。训练前后同一道题的答案是这样的:

-- 训练前(Base 模型,带了 few-shot 示例,仍然写错了)
SELECT COUNT(*)
FROM mz_cf c
JOIN mz_yp y ON c.ypid = y.ypid
WHERE c.cfrq BETWEEN '2024-05-01' AND '2024-05-31'
AND y.yplb LIKE '%降糖%'
AND patient_id IN (
SELECT patient_id FROM mz_gh WHERE ghks = '内分泌科'
)
-- 问题:挂号日期没限定,患者上个月挂了内分泌科但不一定是上个月开的降糖药
-- 训练后(GRPO + RAG,zero-shot,正确)
SELECT COUNT(DISTINCT c.patient_id)
FROM mz_cf c
JOIN mz_yp y ON c.ypid = y.ypid
JOIN mz_gh g ON c.patient_id = g.patient_id AND c.cfrq = g.ghrq
WHERE c.cfrq BETWEEN '2024-05-01' AND '2024-05-31'
AND y.yplb LIKE '%降糖%'
AND g.ghks = '内分泌科'

训练前模型把挂号跟处方分了两步子查询——看起来"有逻辑",但实际上没对齐日期,会把上个月开过降糖药但三个月前挂过内分泌科的人也数进去。训练后模型直接做了三表 JOIN 并显式对齐 c.cfrq = g.ghrq,把两件事锁定在同一天——这才是业务上正确的语义。

Ground Truth 执行奖励就是在这种场景里发力的:子查询的 SQL 执行后也能返回行(不会报错),但返回的行比正确答案多了——F1 软评分会捕捉到「结果接近但不精确」,扣掉分数,告诉模型「方向对但精确度不够」。


六、Insight 提炼

6.1 RL 的核心价值是 Ground Truth 奖励——代理奖励只是辅助

第一个坑是最贵的教训:数据库路径配错,代理奖励撑了整晚的"假训练"。Schema 匹配度涨了,但模型学会的是"看到问题 A 就选表 B"的模式匹配,不是真正的 SQL 生成能力。

代理奖励的作用是"在 ground truth 稀疏时提供连续梯度引导方向"。但如果只有代理奖励没有 ground truth,RL 就退化成了低效版 SFT——在模仿标准答案的表面特征。

开 RL 训练前必须做的 check:随机挑 10 条数据跑一轮,确认 Ground Truth 奖励信号在正常产生非零值。

6.2 DDP 环境下,一个 rank 的阻塞会传染全部

第二个坑是纯工程问题,但杀伤力极大——一条笛卡尔积 SQL 不需要有任何语法错误,就可以在 30 分钟内让 8 张 H100 全部停摆。

signal.alarm 只能发主线程、threading.Timer 会 segfault、最终 ThreadPoolExecutor + .result(timeout=30) 才是稳定方案。这个坑的教训可以推广到任何 RL 训练里跑外部程序(代码执行、API 调用、shell 命令)的场景——永远加超时,永远做好优雅降级。

6.3 算法选择要平衡激进与保守——DAPO 是中间地带

Dr.GRPO 学得飞快但 KL 爆炸 38 倍,REINFORCE++ 太保守学不动。DAPO + cosine LR + KL 约束 β=0.01 是性价比最高的组合。多一个维度:num_generations=8 把"组内全对或全错"的概率从 15.5% 压到 1.7%——组内无差异=这步训练白费,虽然看起来像调参细节,但对训练效率影响很大。

6.4 GRPO 的选择标准:什么时候该用

这个项目恰好展示了 GRPO 最适用的一类场景:当你没有标注数据,但有一个可以自动验证结果的"标准答案源"时。

Text2SQL 有数据库可以执行 SQL 并比对结果;代码生成有测试用例;数学推理有标准答案。这些场景的共同特点是:Ground Truth 可以通过程序自动获得,不需要人工标注。GRPO 就是为这类场景设计的——让模型自己试,试对了加分,试错了扣分。

如果你的场景没有自动验证的手段,RL 的 ground truth 奖励就得靠人工标注来提供——这时候 GRPO 的成本优势就没了,不如直接用 SFT。

原文链接:https://zhuanlan.zhihu.com/p/2015523496889964219