基于 Python 的自然语言处理系列(84):SFT(Supervised Fine-Tuning)原理与实现
SFT 是 RLHF 训练管线中的基石步骤,为后续的奖励建模(RM)和强化学习(PPO)打下良好基础。模式名称数据类型特点基础微调IMDb(纯文本)适合情感分析、语言建模指令微调CodeAlpaca(问答对)对齐生成任务,如问答、摘要Alpaca 样式微调更复杂结构,适用于多任务下一篇我们将进入 RLHF 管线的收尾阶段——。欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉
📌 官方文档:Supervised Fine-tuning in TRL
💻 示例代码:sft.py
🧩 SFT 是 RLHF 中的第一阶段,为训练指令响应式语言模型奠定基础
一、SFT 是什么?
SFT(Supervised Fine-tuning)即有监督微调,通常是 RLHF 流水线中的第一步,其目标是让预训练语言模型学会更好地响应任务指令。其核心流程如下:
-
加载预训练语言模型(如 GPT2)
-
准备含有“指令–响应”格式的数据集
-
使用
trl
库中的SFTTrainer
接口对模型进行微调
本篇内容将涵盖:
-
IMDb 情感数据上的基本 SFT 微调
-
使用 CodeAlpaca 进行指令微调(Instruction Tuning)
-
使用 HuggingFaceH4 数据构建 Alpaca 风格微调任务
二、环境准备
pip install peft==0.7.1 trl==0.7.4 transformers==4.36.2
import os
import torch
import transformers
import trl
# 设置代理与设备
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ['http_proxy'] = 'http://192.41.170.23:3128'
os.environ['https_proxy'] = 'http://192.41.170.23:3128'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
三、IMDb 基础 SFT 微调流程
Step 1:加载 IMDb 数据集
from datasets import load_dataset
# 0 表示负面,1 表示正面
dataset = load_dataset("imdb", split="train")
print(dataset[0])
Step 2:加载模型与 tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "distilgpt2"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
max_seq_length = min(tokenizer.model_max_length, 1024)
Step 3:定义训练器并启动微调
from transformers import TrainingArguments
from trl import SFTTrainer
training_args = TrainingArguments(
output_dir="tmp_trainer",
num_train_epochs=5,
)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset.select(range(1000)),
dataset_text_field="text",
max_seq_length=max_seq_length,
)
trainer.train()
四、CodeAlpaca 数据集上的指令微调(Instruction Tuning)
Step 1:加载 CodeAlpaca 数据集
dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")
print(dataset[20000])
Step 2:构造格式化函数与 collator
def formatting_prompts_func(example):
output_texts = []
for i in range(len(example["instruction"])):
text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
output_texts.append(text)
return output_texts
from trl import DataCollatorForCompletionOnlyLM
response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
Step 3:训练器设置与训练
trainer = SFTTrainer(
model,
train_dataset=dataset.select(range(1000)),
formatting_func=formatting_prompts_func,
data_collator=collator,
)
trainer.train()
五、标准 Alpaca 格式化微调(Instruction + Input)
Step 1:加载 HuggingFaceH4/instruction-dataset
dataset = load_dataset("HuggingFaceH4/instruction-dataset")
dataset = dataset.remove_columns("meta")
Step 2:构建 Alpaca 格式样式
def format_instruction(sample):
return f"""
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{sample['prompt']}
### Response:
{sample['completion']}
""".strip()
Step 3:训练模型
model = AutoModelForCausalLM.from_pretrained("distilgpt2", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token
trainer = SFTTrainer(
model,
train_dataset=dataset["test"],
tokenizer=tokenizer,
max_seq_length=1024,
formatting_func=format_instruction,
)
trainer.train()
六、总结
SFT 是 RLHF 训练管线中的基石步骤,为后续的奖励建模(RM)和强化学习(PPO)打下良好基础。本篇我们介绍了以下几种微调模式:
模式名称 | 数据类型 | 特点 |
---|---|---|
基础微调 | IMDb(纯文本) | 适合情感分析、语言建模 |
指令微调 | CodeAlpaca(问答对) | 对齐生成任务,如问答、摘要 |
Alpaca 样式微调 | HF Instruction Dataset | 更复杂结构,适用于多任务 |
下一篇我们将进入 RLHF 管线的收尾阶段——PPO(Direct Preference Optimization)。
如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!
欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。
谢谢大家的支持!

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。
更多推荐
所有评论(0)