Inference and CoT fine-tuning of DeepSeek-R1-Distill-Llama-8B with unsloth on an RTX 4090
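Following the DeepSeek R1 fine-tuning walkthrough video by "九天Hector" on Bilibili, I tested inference and fine-tuning of the DeepSeek-R1-Distill model on an RTX 4090 GPU. The complete steps are as follows: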
0. Create the unsloth conda environment
conda create --name unsloth_env \
python=3.11 \
pytorch-cuda=12.1 \
pytorch cudatoolkit xformers -c pytorch -c nvidia -c xformers \
-y
conda activate unsloth_env
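Before going further, a quick check can confirm that the new environment actually sees the GPU. This is a minimal sketch that assumes only that PyTorch was installed by the conda command above. `describe_gpu` is a hypothetical helper name, and it returns `{"torch": None}` if torch cannot be imported. On an RTX 4090 I would expect `cuda_available` and `bf16_supported` to both be `True`.

```python
# Sanity check that the conda environment can see the GPU.
# describe_gpu is a hypothetical helper, not part of unsloth or PyTorch.


def describe_gpu() -> dict:
    try:
        import torch
    except ImportError:
        # PyTorch is not installed in this interpreter.
        return {"torch": None}
    info = {
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
    }
    if info["cuda_available"]:
        info["device"] = torch.cuda.get_device_name(0)
        info["bf16_supported"] = torch.cuda.is_bf16_supported()
    return info


if __name__ == "__main__":
    print(describe_gpu())
```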
1. Install unsloth
pip install unsloth -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/
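After installing, it is worth recording which versions were actually installed, because unsloth and the libraries it builds on (transformers, trl, peft) change quickly. For example, the `SFTTrainer(...)` call in step 6.4 passes `tokenizer` and `dataset_text_field` directly. As far as I know, newer trl releases expect these via `processing_class` and `SFTConfig` instead. The sketch below uses only the standard library's `importlib.metadata`. The package list is my assumption about what matters for this walkthrough, and `installed_versions` is a hypothetical helper.

```python
# Print the installed versions of the main packages used in this walkthrough.
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    wanted = ["unsloth", "torch", "transformers", "trl", "peft", "datasets", "wandb"]
    for pkg, ver in installed_versions(wanted).items():
        print(f"{pkg:>12}: {ver or 'not installed'}")
```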
2. Install wandb to monitor training progress (optional)
Register a wandb account and generate an API key (for example: key:0123456789).
pip install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/ wandb
3. Download the 8B model
pip install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/ modelscope
mkdir DeepSeek-R1-Distill-Llama-8B
modelscope download --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B --local_dir ./DeepSeek-R1-Distill-Llama-8B
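Large downloads occasionally stop partway through. The sketch below is a rough completeness check for the model directory. The expected files (`config.json`, `tokenizer_config.json`, and at least one `*.safetensors` shard) are assumptions about a typical Hugging Face-format checkpoint, not a list taken from the model card. `check_model_dir` is a hypothetical helper.

```python
# Rough completeness check for a locally downloaded Hugging Face-format model.
# The expected file names are assumptions about a typical checkpoint layout.
from pathlib import Path


def check_model_dir(path):
    """Return a list of problems found in the model directory (empty if it looks OK)."""
    root = Path(path)
    if not root.is_dir():
        return [f"{root} is not a directory"]
    problems = []
    for name in ("config.json", "tokenizer_config.json"):
        if not (root / name).is_file():
            problems.append(f"missing {name}")
    if not list(root.glob("*.safetensors")):
        problems.append("no *.safetensors weight files found")
    return problems


if __name__ == "__main__":
    issues = check_model_dir("./DeepSeek-R1-Distill-Llama-8B")
    print("looks complete" if not issues else issues)
```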
4. Download the dataset
mkdir medical-o1-reasoning-SFT
modelscope download --dataset AI-ModelScope/medical-o1-reasoning-SFT --local_dir ./medical-o1-reasoning-SFT
5. Inference test
python
from unsloth import FastLanguageModel
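max_seq_length = 2048  # maximum sequence length in tokens
dtype = None  # None = auto-detect (bf16 on GPUs that support it, such as the 4090)
load_in_4bit = False  # load 16-bit weights instead of 4-bit quantized ones
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="./DeepSeek-R1-Distill-Llama-8B",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)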
FastLanguageModel.for_inference(model)
question = "请问如何证明根号2是无理数?"
inputs = tokenizer([question], return_tensors="pt").to("cuda")
outputs = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=1200,
    use_cache=True,
)
response = tokenizer.batch_decode(outputs)
print(response[0])
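DeepSeek-R1-Distill models write their chain of thought before a closing `</think>` tag and then give the answer. Note that `response[0]` above also contains the prompt and any special tokens, because `batch_decode` is called without further options. The hypothetical helper below splits a decoded string at the last `</think>`. It is a sketch that assumes the tag appears literally in the decoded text. If generation stopped before the tag was produced (for example at `max_new_tokens=1200`), it treats the whole string as reasoning.

```python
# Split a decoded R1-style completion into its reasoning and final answer.
# Assumes the reasoning ends with a literal "</think>" tag; if the tag is missing,
# everything is returned as reasoning and the answer is empty.


def split_reasoning(text, end_tag="</think>"):
    head, sep, tail = text.rpartition(end_tag)
    if not sep:
        return text.strip(), ""
    if "<think>" in head:
        # Drop everything up to and including the opening tag (prompt, BOS, ...).
        head = head.split("<think>", 1)[1]
    return head.strip(), tail.strip()


if __name__ == "__main__":
    demo = "<think>\nAssume sqrt(2) = p/q in lowest terms...\n</think>\nSo sqrt(2) is irrational."
    thought, answer = split_reasoning(demo)
    print("REASONING:", thought)
    print("ANSWER:", answer)
```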
6. Fine-tuning test
6.0 Set environment variables and install the datasets package
On the 4090, disable NCCL's P2P and InfiniBand transports:
export NCCL_P2P_DISABLE=1
export NCCL_IB_DISABLE=1
pip install datasets
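The same two variables can also be set from inside Python, which is handy when working in a notebook. They only need to be in the environment before NCCL is initialised. To be safe, the sketch below should run at the top of the script, before `unsloth` or `torch` is imported. NCCL is used for multi-GPU communication, so on a single 4090 these settings should be harmless. `apply_nccl_settings` is a hypothetical helper name.

```python
# Set the NCCL switches from inside Python instead of with `export`.
# Run this before importing unsloth/torch so the values are in place early.
import os

NCCL_SETTINGS = {"NCCL_P2P_DISABLE": "1", "NCCL_IB_DISABLE": "1"}


def apply_nccl_settings(settings=NCCL_SETTINGS, override=False):
    """Write settings into os.environ; keep existing values unless override=True."""
    for key, value in settings.items():
        if override:
            os.environ[key] = value
        else:
            os.environ.setdefault(key, value)
    return {key: os.environ[key] for key in settings}


if __name__ == "__main__":
    print(apply_nccl_settings())
```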
python
6.1 Load the model
from unsloth import FastLanguageModel
max_seq_length = 2048
dtype = None
load_in_4bit = False
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="./DeepSeek-R1-Distill-Llama-8B",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
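Why does `load_in_4bit = False` fit on a 24 GB RTX 4090? A back-of-the-envelope estimate helps: the weights take roughly bytes-per-parameter times parameter count. The sketch assumes about 8.03e9 parameters, the approximate size of the Llama-3.1-8B base model. It counts only the weights. Activations, the KV cache, LoRA weights, and optimizer state come on top of that, and 4-bit formats carry extra overhead for quantization scales.

```python
# Back-of-the-envelope VRAM needed just to hold the model weights.
# 8.03e9 parameters is an approximate figure for a Llama-3.1-8B-sized model.
# Activations, KV cache, LoRA weights and optimizer state are NOT included.

BYTES_PER_PARAM = {"bf16": 2.0, "fp16": 2.0, "4bit": 0.5}


def weight_memory_gib(num_params, precision):
    return num_params * BYTES_PER_PARAM[precision] / 2**30


if __name__ == "__main__":
    n = 8.03e9
    for p in ("bf16", "4bit"):
        print(f"{p:>5}: {weight_memory_gib(n, p):5.1f} GiB of weights (RTX 4090: 24 GB)")
```

Under this assumption the 16-bit weights come to about 15 GiB. That leaves several GB of headroom for LoRA training at `max_seq_length = 2048`. Loading in 4-bit would cut the weights to roughly 4 GiB before overhead.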
6.2 Prepare the dataset
import os
from datasets import load_dataset
train_prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
Write a response that appropriately completes the request.
Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.
### Instruction:
You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning.
Please answer the following medical question.
### Question:
{}
### Response:
<think>
{}
</think>
{}"""
EOS_TOKEN = tokenizer.eos_token  # appended so the model learns where a response ends
def formatting_prompts_func(examples):
    # With batched=True, examples is a dict of column name -> list of values
    inputs = examples["Question"]
    cots = examples["Complex_CoT"]
    outputs = examples["Response"]
    texts = []
    for question, cot, output in zip(inputs, cots, outputs):
        text = train_prompt_style.format(question, cot, output) + EOS_TOKEN
        texts.append(text)
    return {"text": texts}
# "en" selects the English subset; only the first 500 samples are used for this quick test
dataset = load_dataset("./medical-o1-reasoning-SFT", "en", split="train[0:500]", trust_remote_code=True)
dataset = dataset.map(formatting_prompts_func, batched=True)
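To make the batched `map` call more concrete: with `batched=True`, `formatting_prompts_func` receives a dict that maps each column name to a list of values and returns a dict of lists. The returned `"text"` column is added next to the existing columns. The toy sketch below reproduces that shape with a shortened template and a placeholder `"<eos>"` string standing in for `tokenizer.eos_token`. The sample data is made up, and `format_batch` is a hypothetical stand-in for the real function.

```python
# Toy illustration of what dataset.map(..., batched=True) passes to the
# formatting function: a dict of column name -> list of values.
# The template is shortened and "<eos>" is a placeholder, not DeepSeek's EOS string.

TEMPLATE = "### Question:\n{}\n### Response:\n<think>\n{}\n</think>\n{}"
EOS = "<eos>"


def format_batch(examples):
    texts = [
        TEMPLATE.format(q, cot, ans) + EOS
        for q, cot, ans in zip(
            examples["Question"], examples["Complex_CoT"], examples["Response"]
        )
    ]
    return {"text": texts}


if __name__ == "__main__":
    batch = {
        "Question": ["Q1?", "Q2?"],
        "Complex_CoT": ["reasoning 1", "reasoning 2"],
        "Response": ["answer 1", "answer 2"],
    }
    for t in format_batch(batch)["text"]:
        print(t, end="\n---\n")
```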
6.3 Put the model into fine-tuning (LoRA) mode
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=16,
lora_dropout=0,
bias="none",
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
random_state=3407,
use_rslora=False,
loftq_config=None,
)
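It helps to know how many parameters LoRA actually trains here. Each targeted linear layer with input size `in` and output size `out` gets two small matrices, A (r × in) and B (out × r), which adds `r * (in + out)` parameters. The sketch below applies that formula with `r=16` to the seven modules listed above. The layer shapes are assumptions taken from the public Llama-3.1-8B configuration: hidden size 4096, MLP size 14336, 8 key/value heads of 128 dimensions, and 32 layers. It also shows that `lora_alpha=16` with `r=16` gives a scaling factor of 1.0, while `use_rslora=True` would scale by `alpha / sqrt(r)` instead.

```python
# Count the LoRA parameters added with r=16 on the seven projection modules.
# Shapes are assumptions from the public Llama-3.1-8B config.

HIDDEN, INTERMEDIATE, KV_DIM, LAYERS = 4096, 14336, 8 * 128, 32

# (in_features, out_features) of each targeted linear layer in one decoder layer
MODULE_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_param_count(r, shapes=MODULE_SHAPES, layers=LAYERS):
    # LoRA adds A (r x in) and B (out x r) per module: r * (in + out) params.
    per_layer = sum(r * (fan_in + fan_out) for fan_in, fan_out in shapes.values())
    return per_layer * layers


def lora_scale(alpha, r, use_rslora=False):
    # Standard LoRA scales the update by alpha / r; rsLoRA uses alpha / sqrt(r).
    return alpha / (r ** 0.5) if use_rslora else alpha / r


if __name__ == "__main__":
    print(f"trainable LoRA params: {lora_param_count(16):,}")
    print(f"scale alpha/r: {lora_scale(16, 16)}")
```

Under these assumptions the count is 41,943,040, roughly half a percent of the model's approximately 8 billion parameters. If the assumed shapes are right, this should match the trainable-parameter figure that unsloth reports.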
6.4 Create the supervised fine-tuning (SFT) trainer
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=max_seq_length,
dataset_num_proc=2,
args=TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
# Use num_train_epochs = 1, warmup_ratio for full training runs!
warmup_steps=5,
max_steps=60,
learning_rate=2e-4,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=10,
optim="adamw_8bit",
weight_decay=0.01,
lr_scheduler_type="linear",
seed=3407,
output_dir="outputs",
),
)
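The numbers in `TrainingArguments` determine how much data this test run actually sees. The sketch below works through them. The effective batch size is `per_device_train_batch_size * gradient_accumulation_steps` times the number of GPUs, assumed to be 1 here. `lr_at_step` is my reading of `lr_scheduler_type="linear"` with `warmup_steps=5`: a linear ramp from 0 to `learning_rate`, followed by a linear decay to 0 at `max_steps`.

```python
# Arithmetic behind the TrainingArguments above (single GPU assumed).

PER_DEVICE_BATCH = 2
GRAD_ACCUM = 4
NUM_GPUS = 1          # single RTX 4090
MAX_STEPS = 60
WARMUP_STEPS = 5
PEAK_LR = 2e-4
DATASET_SIZE = 500    # split="train[0:500]"


def effective_batch_size(per_device=PER_DEVICE_BATCH, accum=GRAD_ACCUM, gpus=NUM_GPUS):
    return per_device * accum * gpus


def lr_at_step(step, peak=PEAK_LR, warmup=WARMUP_STEPS, total=MAX_STEPS):
    # Linear warmup from 0 to peak, then linear decay to 0 at `total`.
    if step < warmup:
        return peak * step / warmup
    return peak * max(0.0, (total - step) / (total - warmup))


if __name__ == "__main__":
    bs = effective_batch_size()
    seen = bs * MAX_STEPS
    print(f"effective batch size: {bs}")
    print(f"samples seen: {seen} (~{seen / DATASET_SIZE:.2f} epochs of {DATASET_SIZE})")
    for s in (0, 5, 30, 60):
        print(f"step {s:>2}: lr = {lr_at_step(s):.2e}")
```

With these settings, 60 steps × 8 samples = 480 samples, slightly less than one pass over the 500 examples selected in step 6.2. This is consistent with the comment suggesting `num_train_epochs = 1` for full runs.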
6.5 Set up wandb (optional)
import wandb
wandb.login(key="YOUR_WANDB_API_KEY")
Replace "YOUR_WANDB_API_KEY" with the wandb API key generated in step 2.
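Hard-coding the key in a script makes it easy to leak. wandb can also read the key from the `WANDB_API_KEY` environment variable. The variant below skips the login when that variable is unset. It is a sketch, and `login_wandb_from_env` is a hypothetical helper. If you do not want wandb at all, passing `report_to="none"` to `TrainingArguments` should disable the integration.

```python
# Read the wandb key from the environment instead of hard-coding it.
# WANDB_API_KEY is the variable wandb looks for; the helper name is hypothetical.
import os


def login_wandb_from_env(var_name="WANDB_API_KEY"):
    key = os.environ.get(var_name)
    if not key:
        print(f"{var_name} is not set; skipping wandb login")
        return False
    import wandb  # only imported when a key is available

    wandb.login(key=key)
    return True


if __name__ == "__main__":
    login_wandb_from_env()
```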
6.6 Start fine-tuning
trainer_stats = trainer.train()
print(trainer_stats)
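`trainer.train()` returns a `TrainOutput` with three fields: `global_step`, `training_loss` (the average loss over the run), and a `metrics` dict that includes `train_runtime` in seconds. The sketch below formats those fields into one line. `summarize_training` is a hypothetical helper, and in the demo a namedtuple with made-up values stands in for the real object.

```python
# Summarise the TrainOutput returned by trainer.train().
# A namedtuple with made-up values stands in for the real TrainOutput in the demo.
from collections import namedtuple


def summarize_training(stats):
    runtime = stats.metrics.get("train_runtime", 0.0)
    return (
        f"{stats.global_step} steps, average loss {stats.training_loss:.4f}, "
        f"{runtime / 60:.1f} min"
    )


if __name__ == "__main__":
    FakeTrainOutput = namedtuple("FakeTrainOutput", "global_step training_loss metrics")
    demo = FakeTrainOutput(60, 1.2345, {"train_runtime": 600.0})  # placeholder values
    print(summarize_training(demo))
```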
6.7 Test the fine-tuned model
FastLanguageModel.for_inference(model)
question_1 = "A 61-year-old woman with a long history of involuntary urine loss during activities like coughing or sneezing but no leakage at night undergoes a gynecological exam and Q-tip test. Based on these findings, what would cystometry most likely reveal about her residual volume and detrusor contractions?"
question_2 = "Given a patient who experiences sudden-onset chest pain radiating to the neck and left arm, with a past medical history of hypercholesterolemia and coronary artery disease, elevated troponin I levels, and tachycardia, what is the most likely coronary artery involved based on this presentation?"
prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
Write a response that appropriately completes the request.
Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.
### Instruction:
You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning.
Please answer the following medical question.
### Question:
{}
### Response:
<think>{}"""
inputs = tokenizer([prompt_style.format(question_1, "")], return_tensors="pt").to("cuda")
outputs = model.generate(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
max_new_tokens=1200,
use_cache=True,
)
response = tokenizer.batch_decode(outputs)
print(response[0].split("### Response:")[1])
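The `split("### Response:")[1]` above raises `IndexError` if the marker is missing. The printed text may also still end with the EOS token. A slightly more defensive sketch is shown below, where `extract_response` is a hypothetical helper. The marker comes from `prompt_style`. The EOS string is passed in (for example `tokenizer.eos_token`) because its exact value depends on the tokenizer. The same call works for `question_2`.

```python
# Extract only the generated answer from the decoded text in step 6.7.
# The EOS string is passed in because its value depends on the tokenizer.


def extract_response(decoded, eos_token, marker="### Response:"):
    _, sep, after = decoded.partition(marker)
    text = after if sep else decoded
    if eos_token and eos_token in text:
        text = text.split(eos_token, 1)[0]
    return text.strip()


if __name__ == "__main__":
    demo = "...### Question:\nQ?\n### Response:\n<think>reasoning</think>answer<eos>"
    print(extract_response(demo, eos_token="<eos>"))
```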
6.8 Save the LoRA adapter and merge the model weights
new_model_local = "DeepSeek-R1-Medical-COT-Tiny"
model.save_pretrained(new_model_local)
tokenizer.save_pretrained(new_model_local)
model.save_pretrained_merged(new_model_local, tokenizer, save_method = "merged_16bit",)
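After the merge finishes, checking the size of the written weight files is a quick sanity test. At about 2 bytes per parameter, a merged 16-bit 8B model should come to roughly 15 GiB of `*.safetensors`. That figure is an estimate, not a measurement. Step 6.8 saves the LoRA adapter and the merged model into the same directory, so the sketch skips files starting with `adapter_`. As far as I know, PEFT uses that prefix for adapter weights by default. `weight_files_gib` is a hypothetical helper.

```python
# List the merged weight files in the output directory and their total size.
# Files starting with "adapter_" (assumed PEFT adapter naming) are skipped.
from pathlib import Path


def weight_files_gib(path, pattern="*.safetensors", exclude_prefix="adapter_"):
    """Return (file names, total size in GiB) of matching weight files in path."""
    files = sorted(
        f
        for f in Path(path).glob(pattern)
        if not (exclude_prefix and f.name.startswith(exclude_prefix))
    )
    total_bytes = sum(f.stat().st_size for f in files)
    return [f.name for f in files], total_bytes / 2**30


if __name__ == "__main__":
    names, gib = weight_files_gib("DeepSeek-R1-Medical-COT-Tiny")
    print(names)
    print(f"total: {gib:.2f} GiB")
```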
