Text2SQL × GRPO 实战:踩了三个大坑之后学到的教训
原文:用 GRPO 训练 Text2SQL:三个差点没爬出来的坑
这篇文章是基于知乎原文的深度扩写。我以第一视角补充了大量技术细节、直观解释和工程教训——用一个医疗 Text2SQL 项目,把 GRPO RL 训练的三大致命坑摊开来拆解清楚。
这篇文章说了什么
这是我第一个 Text2SQL 项目——用 GRPO 训练一个开源模型,把医生的自然语言问题自动转成 SQL 查询。场景是医疗数据分析:客户有几十个业务数据库,表名是拼音缩写、字段命名混乱、还有大量业务特有的 JOIN 逻辑,分析师每天手写 SQL 写到崩溃。
因为数据库内容涉及患者隐私,合规要求所有 prompt 不能出内网,只能用本地部署的开源模型。正好我们这个场景有大量历史 SQL 日志和数据库,但缺少「自然语言→SQL」的配对标注——找分析师手动标一条要 5-10 分钟,根本标不起。
GRPO 恰好适合这种情况:不需要人工标注,让模型自己生成 SQL,到数据库上跑一遍,结果对了就奖励、错了就惩罚。但这个过程不是一帆风顺的——我踩了三个差点没爬出来的坑。
两个回合下来,有效的实验就这几招
| 方法 | 做了什么 | 效果 |
|---|---|---|
| Ground Truth 执行奖励 | SQL 在真实数据库上执行,结果跟标准答案比对,F1 软评分(部分正确的也有分) | 这是整个训练的生命线——没有它 RL 就是低效版 SFT |
| 代理奖励辅助收敛 | 格式规范 + 语法检查 + Schema 匹配 + N-gram 相似度,权重各 0.1~0.5 | 提供连续梯度,引导方向,execution reward=0 时模型也知道往哪走 |
| DAPO + Cosine LR + KL 约束 | DAPO 算法 + cosine 学习率衰减 + KL 散度约束 β=0.01 | 既不过拟合(Dr.GRPO KL 涨 38 倍),也不学不动(REINFORCE++ 太保守) |
| SQL 执行超时防护 | ThreadPoolExecutor + .result(timeout=30) 包裹每次 SQL 执行 | 防一条笛卡尔积 SQL 拖死整个 8 卡 DDP 训练 |
| 两层 RAG Agent | 第一层业务路由(关键词+embedding),第二层 Schema 检索 + few-shot | 复杂查询从 40% 拉到 87%(训练后),zero-shot 87% 超训练前 few-shot 78% |
试了但翻车的方法:Dr.GRPO(KL 爆炸 38 倍)、REINFORCE++(太保守学不动)、裸 GRPO(前期没 gradient,代理奖励撑不住方向)。
一、先搞懂这些概念
1.1 Text2SQL 是什么
Text2SQL 就是把自然语言转成 SQL 查询。举个例子:
用户输入:
"2024 年 3 月门诊部接诊了多少 65 岁以上的糖尿病患者?"
期望输出:
SELECT COUNT(DISTINCT patient_id)
FROM outpatient_visits
WHERE visit_date BETWEEN '2024-03-01' AND '2024-03-31'
AND age > 65
AND diagnosis LIKE '%糖尿病%'
看起来简单,但实际场景里这个查询涉及的业务规则远超写 SQL 本身——哪张表留存诊信息(outpatient_visits 还是 mz_jz 这种拼音缩写)?年龄字段叫 age 还是 nl(年龄)还是藏在患者主表里?诊断字段存的是 ICD 编码还是中文文本?
Text2SQL 的本质难点不是"写 SQL 语法",而是把用户模糊的自然语言问题,映射到一个你从没见过的、命名随意的数据库 schema 上。
1.2 Ground Truth 奖励 vs 代理奖励
在 GRPO 训练里,奖励函数是灵魂。这个项目里我设计了两类奖励信号,分别扮演不同角色:
Ground Truth 奖励(不可替代):执行正确性。 生成的 SQL 在真实数据库上跑一遍,查询结果跟标准答案比对。这是唯一能回答"SQL 写对了没有"的信号。我给的是 F1 软评分——一条返回了 100 行中 99 行正确的 SQL,能拿到 0.99 分,而不是跟完全写错的一样拿 0 分。
代理奖励(辅助平滑梯度):四个维度。
为了让你看懂每个奖励在干什么,我用同一条 prompt 来演示。假设标准答案 SQL 是:
-- prompt:「上个月门诊部开了多少盒二甲双胍?」
-- 标准答案(Ground Truth):
SELECT SUM(ypsl)
FROM mz_cf
WHERE cfrq BETWEEN '2024-05-01' AND '2024-05-31'
AND ypid = 'METFORMIN_500'
模型生成了一条错得千奇百怪的 SQL:
-- 模型输出:
wrong output: SELECT COUNT(*) FROM yf_kc WHERE rksj > '2024-05-01' AND ypmc LIKE '二甲双胍'
现在用四个代理奖励给这条 SQL 打分:
| 代理奖励 | 怎么算 | 这条 SQL 得分 | 为什么 |
|---|---|---|---|
| 格式规范 | 输出是否包含 <reasoning> + <answer> XML 标签 | 0 | 「wrong output:」不能改成 XML,格式不对 |
| 语法正确 | SQL 能不能被 parser 解析(CREATE TABLE 后直接 SELECT 不报错) | 1 | SELECT COUNT(*) FROM yf_kc WHERE ... 语法没问题 |
| Schema 匹配 | 模型用的表名/列名跟标准答案的 Jaccard 相似度 | 0.2 | 标准答案的 token 集 = {mz_cf, ypsl, cfrq, ypid, METFORMIN_500};模型用了 {yf_kc, rksj, ypmc},交集为空,但 WHERE、AND 等 SQL 关键词共同命中,拿到一点分 |
| N-gram 相似度 | 跟标准 SQL 的 bigram(两词序列)重合度 | 0.15 | 标准 SQL 的 bigram = {SELECT SUM, SUM(, ypsl FROM, FROM mz_cf, ...};模型 bigram = {SELECT COUNT, COUNT(, FROM yf_kc, ...},SELECT 和 FROM 相关 bigram 有部分重合 |
最终这条 SQL 的总 reward = 3.0×0(执行正确性)+ 0.1×0 + 0.2×1 + 0.3×0.2 + 0.2×0.15 = 0.29。Ground Truth 权重 3.0 占大头,但因为表全选错了,执行结果跟标准答案对不上=0。代理奖励提供了 0.29 的连续信号,告诉模型「语法对了、但表和字段全选错了,方向需要大调」。
这就是代理奖励的定位:在 Ground Truth 为零的时候,提供连续梯度告诉你错在哪、有多接近。
为什么需要代理奖励?打个比方你就明白了:
想象你在教一个人打靶,但他的靶子蒙着一层布——你只看得到最终结果(靶纸送过来才知道中没中)。 「执行正确性」就是这张靶纸——它是唯一能回答「打没打中」的信号。但问题是:靶纸要等整轮打完之后才能拿到,中间每一枪你不知道偏了多少。
「代理奖励」就是你在旁边看他端枪的姿势。扳机扣动的时机不对(语法错误)——扣分;枪口朝向明显歪了(表名都没选对)——扣分;但枪口方向大致对了、只是手腕微调的问题(表名对了、JOIN 逻辑接近)——加分,告诉他「方向对了,改细节就行」。
没有代理奖励的训练:模型生成了 100 条 SQL,95 条执行后结果不对(reward=0),5 条对了。模型只知道 5 条是好的,剩下 95 条错在哪了?完全不知道——就像靶纸上只有一个「中」或「没中」,没有任何关于偏了多少的信息。
有代理奖励的训练:那 95 条错的 SQL 也能拿到部分分数——「表名选对了(+0.3)但 JOIN 错了」「语法没问题(+0.2)但 WHERE 条件错了」。模型就能知道:「我离正确的方向还差多远,该调哪里。」
这就是两类奖励的关系:Ground Truth 告诉你什么是正确答案,代理奖励告诉你错了的时候错在哪里、有多接近。 只有前者,模型在黑暗中摸索;只有后者,模型在学表面模式而不是真正解题。
权重分配:执行正确性权重 3.0,四个代理奖励各 0.1~0.5。 保证模型优先学"写对 SQL",而不是"写得像标准答案"。
1.2.1 Advantage 和梯度——分数怎么变成学习信号
搞清楚奖励怎么打分之后,下一步是:这些分数怎么变成模型的学习信号?这里有两个核心概念。
Advantage(优势值):GRPO 不关心 0.9 够不够好,只关心 0.9 比组内其他人好多少。每次对同一个 prompt 生成 8 个答案,算出各自的得分后,用公式算每个答案的 advantage:
advantage = (自己的分数 - 8 个答案的平均分) / 8 个答案的标准差
advantage 越大,说明这条答案在组内越突出,模型越应该往这个方向学。advantage 为负,说明这条答案比组内平均还差,模型应该远离它。
梯度(Gradient):模型参数该往哪个方向改、改多大。advantage 告诉你方向(向好答案靠还是远离坏答案),梯度算出力度(这一步迈多大)。
拿一条具体 prompt 的例子走一遍:
Prompt:「上个月门诊部开了多少盒二甲双胍?」
8 个答案的得分:[0.95, 0.30, 0.85, 0.42, 0.78, 0.25, 0.91, 0.38]
均值 = 0.605,标准差 = 0.278
| 答案 | 得分 | advantage | 模型学到什么 |
|---|---|---|---|
| 第 1 个 | 0.95 | (0.95-0.605)/0.278 = +1.24 | 梯度往这个方向推,strong positive |
| 第 7 个 | 0.91 | (0.91-0.605)/0.278 = +1.10 | 也往这个方向推 |
| 第 5 个 | 0.78 | (0.78-0.605)/0.278 = +0.63 | 弱 positive |
| 第 6 个 | 0.25 | (0.25-0.605)/0.278 = -1.28 | 梯度远离这个方向,strong negative |
这就是 GRPO 的工作原理:不依赖绝对分数有多精确,只依赖组内的相对差异。
但这里有一个致命暗坑:组内全部趋同。
训练中后期,模型已经很稳了,8 个答案的得分全是 0.8:
[0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]
均值 = 0.8,标准差 = 0
advantage = (0.8 - 0.8) / 0 = NaN(除零)
梯度 = 0
标准差为零 → 全部 advantage 等于 NaN → 这一步训练白费。不管 8 个答案质量如何,只要它们得分一样,模型就不会更新。
我的实验里 frac_reward_zero_std 一度飙到 67%。意味着每三步训练就有两步在浪费 8 张 H100 的算力——GRPO 跑了一晚上,其中超过一半的时间在做无用功。
怎么办? 核心思路是强制制造组内方差。解法有两个:
解法一:DAPO 的动态采样。 当检测到 reward 标准差持续过小时,自动增大 num_generations(如 8→16)。样本多了,组内更容易出现差异。
解法二:注入虚拟满分。 这是最便宜也最直接的一个技巧,用数字走一遍你就明白了。
刚才那个死局的场景:8 个答案全是 0.8,标准差为零,梯度 NaN,死锁。
死局:
[0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]
均值 = 0.8,标准差 = 0
→ advantage 全部 NaN → 这一步白算
虚拟满分的做法:在每组末尾追加一条 reward=1.0 的"假答案"。 这条假答案不参与 loss 计算(它不是真的模型输出),只参与 advantage 计算(让标准差不再是零)。
破局:
[0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 1.0]
↑ 追加的虚拟满分
均值 = 0.822,标准差 = 0.067
→ 每个 0.8 的 advantage = (0.8 - 0.822) / 0.067 = -0.33
加了这条 1.0 之后,标准差从 0 变成了 0.067。原来 8 个 0.8 的答案现在全部算出一个 slight negative advantage(-0.33),模型会往远离这些"平平庸庸的 0.8"的方向微调。
为什么这个 trick 有效? 因为 advantage 的相对性——模型不关心 0.8 够不够好,只关心 0.8 比组内其他人差多少。虚拟满分告诉模型:「你这些 0.8 的答案还不是最好的,还有提升空间。」模型就会继续努力往满分方向靠,而不是因为组内全同就原地不动。
代价是什么? 虚拟满分是一个固定锚点,加了之后所有真实答案的 advantage 都会偏低(因为被 1.0 比下去了)。这会让梯度偏保守——学得慢一点,但不会死锁。在「学得慢」和「完全不学」之间,选慢不选停。
但这里有一个更重要的问题:虚拟满分不会把模型带偏吗?
会,如果你理解错了它的作用方式。虚拟满分只告诉模型「你这些 0.8 的答案还不够好」,但它不告诉模型哪个方向是好方向。模型远离 0.8 不代表它在往 1.0 走——它只是往一个随机的方向挪了一步。
那为什么还用?因为它破的不是「方向问题」,而是「死锁问题」。关键在这里:这一步挪完之后,下一轮训练 Ground Truth 执行奖励又会重新打分、重新算 advantage。 只要 Ground Truth 信号在正常工作,模型挪出去之后,下一轮自然会收到真实的反馈——「这个方向偏了,改回来」或「方向对了,继续」。虚拟满分只是帮模型从「动不了」变成「能动」,后续的方向纠偏靠的是 Ground Truth。
什么时候不能用? 这就是第一个坑的教训——如果你的 Ground Truth 执行奖励没有正常工作(比如数据库路径配错了),虚拟满分会让模型在错误的方向上随机游走,越训越差。它只能在「Ground Truth 正常、只是组内方差消失」的场景下用。本质上,它是给已经会解题的模型注射一针兴奋剂让它继续练,不是给还不会解题的模型教它怎么解题。
1.3 DAPO 是什么
DAPO(Dynamic Sampling Policy Optimization)是 GRPO 的改进变体。要理解它解决了什么问题,先看一个具体的数字例子。
GRPO 每次训练,对同一个 prompt 生成 N 个答案,然后比较它们的优劣来算梯度。问题是:如果 N 太小,生成的答案可能全对或全错——组内没差异,这一步训练白费。
举个实际例子。假设模型在某条 prompt 上的正确率是 60%:
| num_generations | 全对概率 | 全错概率 | 「白费」概率 |
|---|---|---|---|
| 4 | 0.6⁴ = 13.0% | 0.4⁴ = 2.6% | 15.6% |
| 8 | 0.6⁸ = 1.7% | 0.4⁸ = 0.07% | 1.7% |
换成白话:num_generations=4 时,每训练 6-7 步就有 1 步白算——8 张 H100 跑一个小时,其中 10 分钟在算空气。num_generations=8 时,这个概率降到 1.7%,基本不浪费。
DAPO 除了增大采样数,还做了两个关键改进:
- 动态采样策略:根据当前 reward 分布自动调整采样数——模型还不稳时多采样(避免运气成分),稳了之后少采样(省算力)。
- 更强的 KL 散度约束:训练过程中不断检查「模型偏离原始策略多远」,偏离超过阈值就拉回来。这是为了防止模型为了刷高 reward 而写出格式对、语义错的 SQL(第三个坑里会详细讲)。
我最终选 DAPO + cosine LR + KL 约束 β=0.01 作为基础算法。
1.4 RAG Agent 是什么
RAG(Retrieval-Augmented Generation)= 先检索相关信息,再让模型基于这些信息生成答案。
这个项目里模型需要知道"这个问题该查哪些表和字段"。但几十个数据库、上百张表的 schema 拼起来远超上下文窗口。RAG Agent 的做法是:先把用户问题做一次"业务路由"(判断属于门诊库还是药房库),然后在目标数据库内检索最相关的表和字段,只把这些相关 schema 塞进 prompt。
用一条具体查询走一遍,你就知道它为什么有必要了。 假设分析师问:
「上个月门诊部开了多少盒头孢?」
这条问题背后涉及的数据库结构是这样的:
| 候选数据库 | 候选表 | 相关字段 |
|---|---|---|
| 门诊库 (30 张表) | mz_cf(处方表)、mz_jz(就诊表)、mz_yp(药品字典)等 | cfrq(处方日期)、ypmc(药品名称)、ypsl(药品数量) |
| 药房库 (25 张表) | yf_kc(库存表)、yf_rk(入库表)、yf_ck(出库表)等 | rksj(入库时间)、cksj(出库时间) |
| 住院库 (40 张表) | zy_yz(医嘱表)、zy_br(患者表)等 | — |
如果眉毛胡子一把抓,把 95 张表的 schema 全塞进 prompt:
- 首先,token 数直接爆上下文窗口。
- 其次,模型会被一堆无关信息干扰——"药房库存"、"住院医嘱"跟「开了多少盒头孢」毫无关系。
RAG Agent 做的事:先识别「门诊」关键词 + embedding 相似度 → 锁定门诊库 → 在门诊库 30 张表里检索跟「头孢」「开药」最相关的 3-5 张表 → 最终塞进 prompt 的只有 3-5 张表的 schema + 3 条历史类似 SQL 示例 + 原始问题。上下文从「整个医院」缩小到「门诊开药」。
二、任务定义
- 场景:医疗数据分析 Text2SQL。客户有几十个数据库,表名拼音缩写、字段命名不规范、JOIN 逻辑是业务特有的
- 约束:包含患者隐私数据的 schema 不能出内网 → 必须本地部署开源模型
- 数据:有历史 SQL 日志和标准执行结果,但没有 NL→SQL 配对标注
- 模型:Qwen2.5-Coder-32B-Instruct,4bit 量化 + LoRA,8 张 H100
三、三个差点没爬出来的坑
第一个坑:数据库路径配错了——RL 在没有真实反馈的情况下跑了整晚
跑了几种 GRPO 变体同时对比(Dr.GRPO、DAPO、REINFORCE++),扔到 8 卡上跑了一晚上。早上看 metrics 曲线:schema 匹配度从 0.50 涨到 0.65,ngram 相似度也在涨。 一切看起来在变好。
然后我发现一个事实:执行正确性始终是 0。
查了一个小时才发现——训练脚本里数据库路径配错了,SQL 根本没有真正执行过。模型收到的所有反馈都来自代理奖励。schema 匹配度涨了,只是因为模型学会了一种模式匹配——"看到这种问题,就选这几张表"。
这不叫学会写 SQL,这叫学会了猜表名。
就像学游泳的人在陆地上对着视频练动作——姿势看起来像了,但一下水就沉。模型学到的不是 SQL 的语义正确性,而是「看到这种问题描述,就挑这几张表」的概率映射。
具体地说,训练了一个晚上之后,模型面对「上个月门诊部开了多少盒二甲双胍」这条 prompt,输出是这样的:
-- 训练后的模型输出(执行正确性始终为 0 的情况下)
SELECT SUM(ypsl)
FROM mz_cf
WHERE cfrq BETWEEN '2024-05-01' AND '2024-05-31'
AND ypid = 'METFORMIN_500'
看着是不是很对?表名 mz_cf 对了,日期过滤对了,药品 ID 也有了。但这一条 SQL 从来没能真正在数据库上跑过——表名对了只是因为模型学会了「看到"门诊"就输出 mz_cf」的模式匹配。一旦问题稍微变一下,比如「上个月门诊部哪些患者同时开了二甲双胍和胰岛素」,模型就露馅了:
-- 稍微变一下 prompt,模型就写出完全不靠谱的 JOIN
SELECT p.name
FROM mz_cf c
JOIN patient p ON c.ypid = p.name
WHERE c.cfrq BETWEEN '2024-05-01' AND '2024-05-31'
AND c.ypid = 'METFORMIN_500' AND c.ypid = 'INSULIN_GLARGINE'
JOIN ON c.ypid = p.name 把药品 ID 跟患者姓名 JOIN——语法对、表都用对了、但逻辑完全不通。更致命的是 ypid = 'METFORMIN_500' AND ypid = 'INSULIN_GLARGINE'——一个字段同时等于两个不同值,这条 SQL 永远不会返回行。这就是只有代理奖励没有 Ground Truth 的下场:模型学会了用正确的积木,但不知道怎么搭。
这个坑直接验证了奖励函数设计的一个关键原则:代理奖励能引导方向,但不能替代 Ground Truth。 没有执行奖励的 RL 训练,本质上退化成了低效版 SFT——在模仿标准答案的表面特征,不是在学语义正确的 SQL。
教训:开 RL 训练前,先跑一轮 smoke test——随机挑 10 条数据,确认 Ground Truth 奖励信号在正常工作。 一个路径错误浪费了一晚上 8×H100 的算力。
第二个坑:一条笛卡尔积 SQL 搞崩了整个 8 卡训练
修好数据库路径后重新训练。execution reward 终于从 0 开始涨了。
跑到 22%,整个 8 卡训练突然挂了。 没有任何预警,所有进程一起 SIGABRT。
查日志看到这行:
WorkNCCL(SeqNum=7034, OpType=_ALLGATHER_BASE)
ran for 1800063ms before timing out
翻译:某个 GPU 在一个同步操作上卡了 30 分钟。其他 7 张卡算完了在等它,集体超时退出。
根因:模型生成了一条笛卡尔积 SQL:
-- 这条 SQL 就是崩掉 8 张 H100 的罪魁祸首
SELECT *
FROM mz_cf c, mz_jz j, mz_yp y
WHERE c.cfrq BETWEEN '2024-05-01' AND '2024-05-31'
三张大表交叉连接(CROSS JOIN)——处方表 mz_cf 10 万行 × 就诊表 mz_jz 50 万行 × 药品字典 mz_yp 8000 行 = 4000 亿行的中间结果集。SQLite 老老实实地在算这张卡上的笛卡尔积,直到其他 7 张卡的 NCCL watchdog 判定超时。
DDP 训练的致命弱点:一个 rank 卡住 = 全体卡住。 这不是 GRPO 的 bug,这是分布式训练的基础约束——ALL_GATHER 需要所有 rank 的数据到齐才能继续。就像 8 个人一起搬桌子,7 个人搬起来了,1 个人不动——桌子就是抬不上去。一个 rank 在 IO 里卡 30 分钟,另外 7 个就只能等 30 分钟。
作为一个 Text2SQL 新手,修了三次才搞定:
第一次:signal.alarm(30) 设超时。 失败了。SIGALRM 只能发给主线程,DDP 的 worker 线程收不到信号,SQL 执行线程完全不受影响。
第二次:threading.Timer。 又失败了。SQLite 连接不是线程安全的,跨线程关闭连接会导致 segfault——没有优雅降级,直接进程崩溃。
第三次:ThreadPoolExecutor。 终于稳了。
from concurrent.futures import ThreadPoolExecutor, TimeoutError
executor = ThreadPoolExecutor(max_workers=1)
def execute_with_timeout(db_path, sql, timeout=30):
try:
future = executor.submit(run_sql, db_path, sql)
result = future.result(timeout=timeout)
return result
except TimeoutError:
return None # reward = 0,这条 SQL 不参与训练
把每次 SQL 执行扔到独立线程,用 .result(timeout=30) 等结果。超时就认栽,返回 None,reward=0,不会拖死其他 rank。
教训:RL 训练里跑外部程序(SQL 执行、代码运行、API 调用),一定要加超时。尤其是 DDP 多卡环境,一个 rank 的阻塞会传染给集群里所有 rank。
第三个坑:算法没选对,KL 散度爆炸了 38 倍
GRPO 有几个有名的变体,我一开始无脑全试了,想对比哪个最好。
Dr.GRPO:学得飞快,但代价是 KL 爆炸。
在讲这个坑之前,先说一下 KL 散度是什么意思。KL 散度衡量的是「现在的模型跟原始模型变得有多不一样」。 你可以把它想象成一条橡皮筋——原始模型在橡皮筋的一端,训练过程中模型越学越远,橡皮筋就拉得越紧。0.00028 是训练刚开始时的松弛状态,0.0108 是橡皮筋被拉到接近崩断的状态。一旦崩断,模型就开始胡言乱语——语法看起来对,但内容已经脱离逻辑。
Dr.GRPO 去掉了组内标准差归一化——相当于给梯度加了杠杆。这就像开车:标准 GRPO 踩油门是按固定力度,Dr.GRPO 是每次把油门踩的力度乘以当前速度——前期开得确实快,reward 蹭蹭涨。但到了中后期,KL 散度从 0.00028 飙到了 0.0108,涨了 38 倍。
这时候模型开始出现诡异行为——生成一些语法正确但语义荒谬的 SQL。比如把患者年龄跟药品库存做 JOIN:
-- 这条语法完全正确,但在逻辑上是扯淡
SELECT AVG(p.age), SUM(s.stock_qty)
FROM patient p
JOIN drug_stock s ON p.age = s.stock_qty
为什么会写出这种东西?因为模型为了刷高 reward,学会了钻规则的空子——语法检查满分、格式规范满分、甚至 N-gram 跟标准答案相似度也在涨。但语义已经完全脱离了数据库的真实逻辑关系。
就像学生摸清了选择题的出题规律:虽然不知道知识点,但能靠「三长一短选最短」拿到高分。 —— KL 爆炸的本质就是模型在学应试技巧,而不是学解题能力。
REINFORCE++:稳如老狗,但也学不到东西。
跟 Dr.GRPO 相反,REINFORCE++ 极其保守。KL 散度几乎不变,reward 也纹丝不动。就像开着手刹踩油门——太稳了,稳到完全没进步。
最后选了 DAPO + 三个关键调整:
| 调整 | 做了什么 | 为什么 |
|---|---|---|
| Cosine LR schedule | 前期正常学,后期自动降低学习率 | 防止后期策略漂移失控——学得差不多了就降速 |
| KL 约束 β=0.01 | 策略偏离超阈值时,在 loss 里加惩罚项 | 相当于给橡皮筋加了安全锁——拉到一定程度就自动拉回来 |
num_generations 4→8 | 每个 prompt 生成 8 个答案做组内比较 | 组内全对或全错的概率从 15.5% 降到 1.7% |
这个组合兼顾了学习速度和稳定性:Dr.GRPO 是有进度但失控,REINFORCE++ 是稳但没进度,DAPO 在中间——进度看得见,橡皮筋拉不断。
教训:算法选择要平衡激进与保守。 学得快的容易崩(Dr.GRPO),太稳的学不动(REINFORCE++)。DAPO + cosine LR + KL 约束是一个不错的起点——它是中间地带,有进度但不失控。
四、推理侧:两层 RAG Agent
模型训好了,但实际推理时还有一个问题:客户有几十个数据库、上百张表,模型需要先知道"这个问题该查哪些表"才能写 SQL。把所有表的 schema 全部塞进 prompt 不可能——几十个库的 schema 拼起来远超上下文窗口。
我加了两层 RAG Agent。用一条具体查询走一遍完整流程:
用户问:「上个月门诊部头孢类药品的日均消耗量是多少?」
第一层:业务路由。 把用户问题跟业务领域做匹配。
| 匹配方式 | 做了什么 | 结果 |
|---|---|---|
| 关键词匹配 | "门诊部" → 门诊库 | 门诊库得分最高 |
| embedding 相似度 | 整句跟各库描述向量算余弦相似度 | 门诊库 0.91 > 药房库 0.34 > 住院库 0.12 |
这一步把 3 个候选库缩小到 1 个:门诊库。
第二层:Schema 检索。 在门诊库内用问题检索最相关的表和字段。
| 检索步骤 | 具体操作 |
|---|---|
| 表检索 | 用「头孢」「消耗量」「日」在门诊库 30 张表中召回 top-5 张:mz_cf(处方表)、mz_yp(药品字典)、mz_jz(就诊表)、mz_gh(挂号表)、mz_sf(收费表) |
| 字段检索 | 从上一步 5 张表的字段中筛选出最相关的 15-20 个字段 |
| few-shot 检索 | 从历史 SQL 日志中找出 3 条语义相似的查询作为示例 |
最终塞进 prompt 的内容只有三样:
1. 相关 schema 定义(5 张表的关键字段)
- mz_cf: cfrq(处方日期), ypid(药品ID), ypsl(数量)
- mz_yp: ypid(药品ID), ypmc(药品名称), yplb(药品类别)
- mz_jz: jzrq(就诊日期), patient_id, ksdm(科室代码)
...
2. 3 条 few-shot SQL 示例
- "上月门诊抗生素处方总量" → SELECT SUM(c.ypsl) FROM mz_cf c JOIN mz_yp y...
- "本周门诊接诊人次" → SELECT COUNT(*) FROM mz_jz WHERE jzrq...
3. 用户的自然语言问题
「上个月门诊部头孢类药品的日均消耗量是多少?」
效果上,RAG Agent 对复杂查询的提升最明显——简单查询模型 zero-shot 就能搞定,但涉及多表 JOIN、子查询的复杂场景,few-shot 示例引导着模型走通关键 JOIN 路径,效果天差地别。
五、最终结果
训练前(Base Qwen2.5-Coder-32B + RAG):
| 难度 | few-shot 准确率 |
|---|---|
| 简单 | 78% |
| 复杂 | 40% |
训练后(GRPO + RAG):
| 难度 | zero-shot | few-shot |
|---|---|---|
| 简单 | 87% | 95% |
| 复杂 | 62% | 87% |
几个值得注意的对比:
- zero-shot 87% 超过了训练前的 few-shot 78%。 GRPO 确实让模型学到了东西,而不只是靠 few-shot 示例撑着——模型内部的 SQL 生成能力真的变强了
- 复杂查询提升最大:40%→87%(+47 点)。 多表 JOIN 和嵌套子查询是 RL 训练收益最大的地方——规则奖励函数对这些场景区分度太差,Ground Truth 执行奖励反而精准
- few-shot 在训练后仍有 ~8% 的额外提升。 RAG Agent 和模型能力是互补的——模型本身写 SQL 强了,给几个参照示例还能再往上提一截
看一条复杂查询的训练前后对比,就能直观感受 +47 点提升长什么样。
用户问:「上个月门诊部开了降糖药的患者中,有多少人同时挂了内分泌科的号?」
这条查询涉及三表 JOIN(处方表 mz_cf + 药品字典 mz_yp + 挂号表 mz_gh),子查询过滤降糖药分类,再加 GROUP BY 去重。训练前后同一道题的答案是这样的:
-- 训练前(Base 模型,带了 few-shot 示例,仍然写错了)
SELECT COUNT(*)
FROM mz_cf c
JOIN mz_yp y ON c.ypid = y.ypid
WHERE c.cfrq BETWEEN '2024-05-01' AND '2024-05-31'
AND y.yplb LIKE '%降糖%'
AND patient_id IN (
SELECT patient_id FROM mz_gh WHERE ghks = '内分泌科'
)
-- 问题:挂号日期没限定,患者上个月挂了内分泌科但不一定是上个月开的降糖药
-- 训练后(GRPO + RAG,zero-shot,正确)
SELECT COUNT(DISTINCT c.patient_id)
FROM mz_cf c
JOIN mz_yp y ON c.ypid = y.ypid
JOIN mz_gh g ON c.patient_id = g.patient_id AND c.cfrq = g.ghrq
WHERE c.cfrq BETWEEN '2024-05-01' AND '2024-05-31'
AND y.yplb LIKE '%降糖%'
AND g.ghks = '内分泌科'
训练前模型把挂号跟处方分了两步子查询——看起来"有逻辑",但实际上没对齐日期,会把上个月开过降糖药但三个月前挂过内分泌科的人也数进去。训练后模型直接做了三表 JOIN 并显式对齐 c.cfrq = g.ghrq,把两件事锁定在同一天——这才是业务上正确的语义。
Ground Truth 执行奖励就是在这种场景里发力的:子查询的 SQL 执行后也能返回行(不会报错),但返回的行比正确答案多了——F1 软评分会捕捉到「结果接近但不精确」,扣掉分数,告诉模型「方向对但精确度不够」。
六、Insight 提炼
6.1 RL 的核心价值是 Ground Truth 奖励——代理奖励只是辅助
第一个坑是最贵的教训:数据库路径配错,代理奖励撑了整晚的"假训练"。Schema 匹配度涨了,但模型学会的是"看到问题 A 就选表 B"的模式匹配,不是真正的 SQL 生成能力。
代理奖励的作用是"在 ground truth 稀疏时提供连续梯度引导方向"。但如果只有代理奖励没有 ground truth,RL 就退化成了低效版 SFT——在模仿标准答案的表面特征。
开 RL 训练前必须做的 check:随机挑 10 条数据跑一轮,确认 Ground Truth 奖励信号在正常产生非零值。
6.2 DDP 环境下,一个 rank 的阻塞会传染全部
第二个坑是纯工程问题,但杀伤力极大——一条笛卡尔积 SQL 不需要有任何语法错误,就可以在 30 分钟内让 8 张 H100 全部停摆。
signal.alarm 只能发主线程、threading.Timer 会 segfault、最终 ThreadPoolExecutor + .result(timeout=30) 才是稳定方案。这个坑的教训可以推广到任何 RL 训练里跑外部程序(代码执行、API 调用、shell 命令)的场景——永远加超时,永远做好优雅降级。
6.3 算法选择要平衡激进与保守——DAPO 是中间地带
Dr.GRPO 学得飞快但 KL 爆炸 38 倍,REINFORCE++ 太保守学不动。DAPO + cosine LR + KL 约束 β=0.01 是性价比最高的组合。多一个维度:num_generations=8 把"组内全对或全错"的概率从 15.5% 压到 1.7%——组内无差异=这步训练白费,虽然看起来像调参细节,但对训练效率影响很大。
6.4 GRPO 的选择标准:什么时候该用
这个项目恰好展示了 GRPO 最适用的一类场景:当你没有标注数据,但有一个可以自动验证结果的"标准答案源"时。
Text2SQL 有数据库可以执行 SQL 并比对结果;代码生成有测试用例;数学推理有标准答案。这些场景的共同特点是:Ground Truth 可以通过程序自动获得,不需要人工标注。GRPO 就是为这类场景设计的——让模型自己试,试对了加分,试错了扣分。
如果你的场景没有自动验证的手段,RL 的 ground truth 奖励就得靠人工标注来提供——这时候 GRPO 的成本优势就没了,不如直接用 SFT。