📌 官方文档:Supervised Fine-tuning in TRL
💻 示例代码:sft.py
🧩 SFT 是 RLHF 中的第一阶段,为训练指令响应式语言模型奠定基础

一、SFT 是什么?

        SFT(Supervised Fine-tuning)即有监督微调,通常是 RLHF 流水线中的第一步,其目标是让预训练语言模型学会更好地响应任务指令。其核心流程如下:

  1. 加载预训练语言模型(如 GPT2)

  2. 准备含有“指令–响应”格式的数据集

  3. 使用 trl 库中的 SFTTrainer 接口对模型进行微调

        本篇内容将涵盖:

  • IMDb 情感数据上的基本 SFT 微调

  • 使用 CodeAlpaca 进行指令微调(Instruction Tuning)

  • 使用 HuggingFaceH4 数据构建 Alpaca 风格微调任务

二、环境准备

pip install peft==0.7.1 trl==0.7.4 transformers==4.36.2
import os
import torch
import transformers
import trl

# 设置代理与设备
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ['http_proxy']  = 'http://192.41.170.23:3128'
os.environ['https_proxy'] = 'http://192.41.170.23:3128'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

三、IMDb 基础 SFT 微调流程

 Step 1:加载 IMDb 数据集

from datasets import load_dataset

# 0 表示负面,1 表示正面
dataset = load_dataset("imdb", split="train")
print(dataset[0])

Step 2:加载模型与 tokenizer

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "distilgpt2"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
max_seq_length = min(tokenizer.model_max_length, 1024)

Step 3:定义训练器并启动微调

from transformers import TrainingArguments
from trl import SFTTrainer

training_args = TrainingArguments(
    output_dir="tmp_trainer",
    num_train_epochs=5,
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset.select(range(1000)),
    dataset_text_field="text",
    max_seq_length=max_seq_length,
)

trainer.train()

四、CodeAlpaca 数据集上的指令微调(Instruction Tuning)

Step 1:加载 CodeAlpaca 数据集

dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")
print(dataset[20000])

Step 2:构造格式化函数与 collator

def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example["instruction"])):
        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
        output_texts.append(text)
    return output_texts

from trl import DataCollatorForCompletionOnlyLM

response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

Step 3:训练器设置与训练

trainer = SFTTrainer(
    model,
    train_dataset=dataset.select(range(1000)),
    formatting_func=formatting_prompts_func,
    data_collator=collator,
)

trainer.train()

五、标准 Alpaca 格式化微调(Instruction + Input)

Step 1:加载 HuggingFaceH4/instruction-dataset

dataset = load_dataset("HuggingFaceH4/instruction-dataset")
dataset = dataset.remove_columns("meta")

Step 2:构建 Alpaca 格式样式

def format_instruction(sample):
    return f"""
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{sample['prompt']}

### Response:
{sample['completion']}
""".strip()

Step 3:训练模型

model = AutoModelForCausalLM.from_pretrained("distilgpt2", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token

trainer = SFTTrainer(
    model,
    train_dataset=dataset["test"],
    tokenizer=tokenizer,
    max_seq_length=1024,
    formatting_func=format_instruction,
)

trainer.train()

六、总结

        SFT 是 RLHF 训练管线中的基石步骤,为后续的奖励建模(RM)和强化学习(PPO)打下良好基础。本篇我们介绍了以下几种微调模式:

模式名称 数据类型 特点
基础微调 IMDb(纯文本) 适合情感分析、语言建模
指令微调 CodeAlpaca(问答对) 对齐生成任务,如问答、摘要
Alpaca 样式微调 HF Instruction Dataset 更复杂结构,适用于多任务

        下一篇我们将进入 RLHF 管线的收尾阶段——PPO(Direct Preference Optimization)

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

Logo

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。

更多推荐