
用Unsloth微调一个老中医大模型
本文介绍了如何使用Unsloth框架微调大语言模型,以《伤寒论》数据集为例训练一个中医专家模型。Unsloth显著降低了微调的资源需求。文章涵盖了从环境配置、模型选择、数据准备到训练部署的完整流程,为垂直领域模型微调提供了实用参考。
1. 关于Unsloth
2. Unsloth 的核心优势
3. 使用要求
4. 安装 Unsloth
5. 选择模型
6. 数据集准备
7. 开始微调
7.1 引入依赖
7.2 加载模型
7.3 加载数据集
7.4 定义 LoRA
7.5 使用 `SFTTrainer` 进行训练
7.6 模型保存
7.7 训练过程
8. 测试微调的模型
在实际应用中,我们常常面临特定场景下的问题需求。此时,通过指令微调可以让模型更好地适应这些具体任务,从而提升模型的实用性和表现。
比如:在医疗健康领域,医生希望让大模型更好地理解中医诊断和治疗的知识体系,从而能够辅助分析病情、推荐中药方剂,甚至自动生成病历摘要。又如,在法律行业,律师团队希望通过微调,让大模型能够更准确地解读中国法律条文、判例和合同文本,辅助法律检索和文书生成。此外,在教育领域,教师可以通过指令微调,让模型更贴合本地教材内容,自动批改作业、生成个性化学习建议。这些场景都需要针对特定任务和数据对大模型进行定制化微调,以提升其在实际应用中的表现和价值。
正好最近身体不适,并且得到一本医疗秘籍《伤寒论》,用 Unsloth
来微调一个垂直领域的模型出来。
完整代码,在公众号「AI取经路」发消息「微调」获取
1. 关于Unsloth
大型语言模型(LLM)的微调有很大的资源挑战,如高昂的内存需求和漫长的训练时间。
传统的微调方法,可能需要数十小时才能完成一项任务,并且常常导致内存不足问题 。这种特性限制了个人开发者和小型团队对 LLM 进行定制化和优化的能力。
Unsloth 正是为了解决这些痛点而诞生的。
它是一个专门为加速 LLM 微调并显著降低内存消耗而设计的 Python 框架 。
Unsloth 实现了显著的性能提升,同时保持了与 Hugging Face 生态系统的完全兼容性 。这使得用户即使在免费的 Colab GPU 或配备 GPU 的笔记本电脑等有限硬件上,也能够高效地进行 LLM 微调 。
2. Unsloth 的核心优势
速度提升
在Alpaca 数据集上进行了测试,使用的 batch 大小为 2,gradient accumulation steps 为 4,rank 为 32,并在所有线性层(q、k、v、o、gate、up、down)上应用了 QLoRA
Model | VRAM | Unsloth speed | VRAM reduction | Longer context | Hugging Face + FA2 |
Llama 3.3 (70B) | 80GB | 2x | >75% | 13x longer | 1x |
Llama 3.1 (8B) | 80GB | 2x | >70% | 12x longer | 1x |
测试结果显示,两种模型在显存使用上均为80GB,速度均提升了2倍。Llama 3.3的显存减少了超过75%,能够处理的上下文长度提升了13倍;而Llama 3.1的显存减少了超过70%,上下文长度提升了12倍。
API 简化
Unsloth 提供了一个简洁的 API,显著降低了 LLM 微调的复杂性 。它将模型加载、量化、训练、评估、保存、导出以及与 Ollama、llama.cpp 和 vLLM 等推理引擎的集成等所有工作流程进行了简化 。
Hugging Face 生态系统兼容性
Unsloth 是在 Hugging Face Transformers 库之上构建的,这使其能够充分利用 Hugging Face 强大的生态系统,同时添加自己的增强功能来简化微调过程 。
它与 Hugging Face Hub、Transformers、PEFT 和 TRL 等组件完全兼容 。这意味着用户可以无缝地访问 Hugging Face 提供的丰富模型和数据集资源,并利用 Unsloth 的优化进行训练。
硬件兼容性与可访问性
Unsloth 支持非常多的 NVIDIA GPU,从 2018 年及以后发布的型号,包括 V100、T4、Titan V、RTX 20、30、40 系列、A100 和 H100 等,最低 CUDA 能力要求为 7.0。即使是 GTX 1070 和 1080 也能工作,尽管速度较慢 。
Unsloth 可以在 Linux 和 Windows 操作系统上运行 。
广泛的模型支持
Unsloth 支持所有 Transformer 风格的模型,包括 Llama、DeepSeek、TTS、Qwen、Mistral、Gemma 等主流 LLM。它还支持多模态模型、文本到语音 (TTS)、语音到文本 (STT) 模型、BERT 模型以及扩散模型。
动态量化
Unsloth 引入了动态 4 位量化 (Dynamic 4-bit Quantization) 技术,这是一种智能的量化策略,通过动态选择不对某些参数进行量化,从而在仅增加不到 10% VRAM 的情况下显著提高准确性 。
社区的活跃
从GitHub上可以看出,他的受欢迎程度非常高,目前已经43k颗星
3. 使用要求
支持 Linux 和 Windows。
支持 2018 年及以后的 NVIDIA 显卡,包括 Blackwell RTX 50 系列。最低要求 CUDA 计算能力为 7.0(如 V100、T4、Titan V、RTX 20、30、40 系列,以及 A100、H100、L40 等)。GTX 1070、1080 虽然可以使用,但运行速度较慢。
相关阅读:一文说清楚CUDA环境
GPU内存要求:
Model parameters | QLoRA (4-bit) VRAM | LoRA (16-bit) VRAM |
3B | 3.5 GB | 8 GB |
7B | 5 GB | 19 GB |
8B | 6 GB | 22 GB |
9B | 6.5 GB | 24 GB |
11B | 7.5 GB | 29 GB |
14B | 8.5 GB | 33 GB |
27B | 22GB | 64GB |
32B | 26 GB | 76 GB |
40B | 30GB | 96GB |
70B | 41 GB | 164 GB |
81B | 48GB | 192GB |
90B | 53GB | 212GB |
405B | 237 GB | 950 GB |
4. 安装 Unsloth
# 初始化一个名为 demo-unsloth 的项目,并指定 Python 版本为 3.11.9
uv init demo-unsloth -p 3.11.9
# 进入项目目录
cd demo-unsloth
# 创建虚拟环境
uv venv
# 激活虚拟环境(Windows 下)
.venv\Scripts\activate
# 安装 triton-windows,要求版本3.3.1.post19
uv pip install triton-windows==3.3.1.post19
# 安装支持 CUDA 12.6 的 PyTorch 及其相关库
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
安装unsloth
uv pip install unsloth
5. 选择模型
Unsloth 支持非常多的预训练模型,包括 Llama、DeepSeek、TTS、Qwen、Mistral 和 Gemma 系列 LLM 3。
在选择模型时,需要注意模型名称的后缀。以 unsloth-bnb-4bit
结尾的模型表示它们是 Unsloth 动态 4 位量化模型,与标准 bnb-4bit
模型相比,这些模型在略微增加 VRAM 使用的情况下提供了更高的精度。
https://docs.unsloth.ai/get-started/all-our-models
我的显卡是RTX 2050显卡,4GB显存,我选一个比较小的版本 unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit
下载模型:
huggingface-cli download --resume-download unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit --local-dir ./ckpts/qwen2.5-1.5b-instruct-bnb-4bit
6. 数据集准备
数据集的质量和格式直接决定了模型微调的效果。
一个高质量的数据集不仅需要覆盖目标任务的核心内容,还应保证问答对的准确性、完整性和多样性。
此外,数据集的规模也会影响模型的泛化能力,数据量越大、覆盖面越广,模型在实际应用中的表现通常会更好。
因此,在微调前应充分清洗、标注和检查数据,确保其能够有效支撑下游任务的训练需求。
因为要要训练一个老中医,所以找了一本《伤寒论》,通过工具把他拆成问答对,格式如下
[
{
"Question": "伤寒一日,巨阳受之会出现什么症状?",
"Response": "伤寒一日,巨阳受之会出现头项痛、腰脊强的症状。因为巨阳者,诸阳之属也,其脉连于风府,为诸阳主气,所以先受邪。 "
},
{
"Question": "三阴受病,厥阴受之会出现什么症状",
"Response": "三阴受病,若厥阴受之,厥阴脉循阴器而络于肝,会出现烦满而囊缩的症状。 若伤寒循常无变,十二日厥阴病衰,会出现囊纵、少腹微下的情况。 "
},
]
7. 开始微调
7.1 引入依赖
# 导入必要的库
from unsloth import FastLanguageModel, is_bfloat16_supported
from transformers import TrainingArguments, EarlyStoppingCallback
from trl import SFTTrainer
from datasets import load_dataset
# 模型配置参数
max_seq_length = 2048 # 模型处理文本的最大序列长度,支持长文本输入
7.2 加载模型
# 加载预训练模型和分词器
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="ckpts/qwen2.5-1.5b-instruct-bnb-4bit",
max_seq_length=max_seq_length,
dtype=None, # 自动检测最佳数据类型(bfloat16或float16)
load_in_4bit=True, # 使用4bit量化加载模型,大幅减少显存占用
)
7.3 加载数据集
大型语言模型需要以特定的“提示风格”或“聊天模板”来接收输入,以便它们能够正确理解任务、指令、输入和预期响应之间的关系。
例如,一个常用的提示风格是:
“请先阅读下方的任务指令,再结合提供的上下文信息,生成恰当的回复。### 指令:{} ### 上下文:{} ### 回复:{}”
有效的微调需要仔细关注输入格式,因为它直接影响模型学习的内容和响应方式
下面代码中加载我们的《伤寒论》数据集( data/datasets-2025-08.json
),并使用 formatting_data
函数将其转换为 Unsloth 所需的格式,此函数会结合输入和输出,按照 prompt_style
进行格式化,并添加 EOS_TOKEN
(End-of-Sentence Token),EOS_TOKEN
对于避免模型生成重复内容至关重要。
# 定义训练数据格式化模板
# 使用中医专家角色设定,专门针对《伤寒论》问答任务
train_prompt_style = """你是一位精通中医理论的专家,特别擅长《伤寒论》的理论和实践应用。
请根据《伤寒论》的经典理论,准确回答以下问题。
### 问题:
{}
### 回答:
{}
"""
# 加载训练数据集
# 从JSON文件加载伤寒论问答数据集,包含问题和回答对
dataset = load_dataset("json", data_files="data/datasets-2025-08.json", split="train")
def formatting_data(examples):
"""格式化数据集函数,将问答对转换为训练格式
将原始的问题和回答对格式化为模型训练所需的文本格式,
添加角色设定和结构化模板。
Args:
examples: 包含Question、Response字段的数据样本字典
Returns:
dict: 包含格式化后文本的字典,键为"text"
"""
questions = examples["Question"]
responses = examples["Response"]
texts = []
for q, r in zip(questions, responses):
# 使用模板格式化每个问答对,并添加结束标记
text = train_prompt_style.format(q, r) + tokenizer.eos_token
texts.append(text)
# print(f"数据集: {texts}")
return {"text": texts}
# 应用数据格式化函数
dataset = dataset.map(formatting_data, batched=True, num_proc=1)
# 数据集分割:80%用于训练,20%用于验证
# 使用固定随机种子确保结果可复现
train_test_split = dataset.train_test_split(test_size=0.2, seed=3407)
train_dataset = train_test_split['train']
eval_dataset = train_test_split['test']
print(f"数据集加载完成 - 训练集: {len(train_dataset)} 条, 验证集: {len(eval_dataset)} 条")
7.4 定义 LoRA
# 添加LoRA(Low-Rank Adaptation)权重配置
# LoRA是一种高效的微调方法,只训练少量参数即可适应新任务
model = FastLanguageModel.get_peft_model(
model,
r=32, # LoRA矩阵的秩,值越大表达能力越强,但参数量也越多
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ],
lora_alpha=64, # LoRA缩放参数
lora_dropout=0.1, # LoRA dropout率,防止过拟合,值越小正则化越弱
bias="none", # 偏置项处理方式,"none"表示不训练偏置,节省参数
use_gradient_checkpointing="unsloth", # 使用Unsloth优化的梯度检查点,支持超长序列
random_state=3407, # 随机种子,确保结果可复现
use_rslora=False, # 是否使用Rank Stabilized LoRA,当前使用标准LoRA
loftq_cnotallow=None, # LoftQ配置,用于更精确的量化
)
7.5 使用 SFTTrainer
进行训练
# 创建SFT(Supervised Fine-Tuning)训练器
# 使用监督学习方式微调模型,适用于问答任务
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset, # 训练数据集
eval_dataset=eval_dataset, # 验证数据集,用于评估模型性能
dataset_text_field="text", # 数据集中文本字段的名称
max_seq_length=max_seq_length, # 最大序列长度
dataset_num_proc=1, # 数据处理进程数,设为1避免缓存冲突
packing=False, # 是否使用序列打包,短序列时可设为True提升5倍训练速度
callbacks=[
EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.0005), # 早停机制,防止过拟合
],
args=TrainingArguments(
# 批次大小配置
per_device_train_batch_size=2, # 每个GPU的训练批次大小
per_device_eval_batch_size=2, # 每个GPU的验证批次大小
# 梯度累积配置
gradient_accumulation_steps=8, # 梯度累积步数,有效批次大小 = batch_size * gradient_accumulation_steps
# 学习率配置
warmup_ratio=0.15, # 学习率预热比例,前15%的步数用于预热
learning_rate=2e-5, # 学习率,控制参数更新步长
# 训练轮数和步数配置
# max_steps = 200, # 最大训练步数(可选)
num_train_epochs=5, # 训练轮数,给模型充分学习机会
# 精度配置
fp16=not is_bfloat16_supported(), # 是否使用16位浮点精度训练
bf16=is_bfloat16_supported(), # 是否使用bfloat16精度训练(更稳定)
# 日志和监控配置
logging_steps=2, # 每2步记录一次日志
eval_steps=10, # 每10步进行一次验证评估
eval_strategy="steps", # 按步数进行验证
# 模型保存配置
save_steps=20, # 每20步保存一次模型检查点
save_strategy="steps", # 按步数保存模型
save_total_limit=5, # 最多保存5个检查点,节省存储空间
# 最佳模型配置
load_best_model_at_end=True, # 训练结束时自动加载最佳模型
metric_for_best_model="eval_loss", # 使用验证损失作为最佳模型指标
greater_is_better=False, # 损失越小越好
# 正则化配置
weight_decay=0.001, # 权重衰减,防止过拟合
max_grad_norm=1.0, # 梯度裁剪阈值,防止梯度爆炸
# 学习率调度配置
lr_scheduler_type="cosine", # 使用余弦退火学习率调度器
# 优化器配置
optim="adamw_8bit", # 使用8位AdamW优化器,节省显存
# 数据加载配置
dataloader_num_workers=0, # 数据加载器工作进程数,设为0避免多进程冲突
# 输出配置
output_dir="outputs", # 模型输出和检查点保存目录
# 随机种子
seed=3407, # 随机种子,确保结果可复现
),
)
# 开始训练
train_stats = trainer.train()
print(f"训练完成,训练损失: \n {train_stats}")
7.6 模型保存
Unsloth 提供了多种保存微调后模型的方法,每种方法都有其特定的用途:
save_pretrained
此方法仅保存 LoRA 适配器权重。这些文件通常很小,只包含模型修改部分,包括 adapter_config.json
和 adapter_model.safetensors
。这种方法适用于需要灵活切换不同 LoRA 适配器或节省存储空间的情况。
save_pretrained_gguf
这是一种新的优化格式(GGUF),支持更好的元数据处理。文件同样较小且经过量化,包括 model.gguf
和 tokenizer.json
。Unsloth 的动态量化在 GGUF 导出中通过智能的层特定量化策略,进一步增强了模型性能和效率。
# 保存微调后的模型权重
# 只保存LoRA权重,原始模型权重保持不变
model.save_pretrained("ckpts/qwen2.5-1.5b-instruct-bnb-4bit-lora")
# 保存分词器(tokenizer),以便后续加载和推理时使用
tokenizer.save_pretrained("ckpts/qwen2.5-1.5b-instruct-bnb-4bit-lora")
#==================================================================================
# model.save_pretrained_gguf("ckpts/qwen2.5-1.5b-instruct-bnb-4bit-gguf", tokenizer, quantization_method="q4_k_m")
7.7 训练过程
开始时显示基础信息:
==((====))== Unsloth 2025.7.8: Fast Qwen3 patching. Transformers: 4.53.3.
\\ /| NVIDIA GeForce RTX 2050. Num GPUs = 1. Max memory: 4.0 GB. Platform: Windows.
O^O/ \_/ \ Torch: 2.7.1+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.1
\ / Bfloat16 = TRUE. FA [Xformers = 0.0.31.post1. FA2 = False]
"-____-" Free license: http://github.com/unslothai/unsloth
# 批注:系统配置信息
# - 使用RTX 2050显卡,4GB显存
# - 支持bfloat16精度训练
# - 使用Xformers优化注意力机制
再给出微调的配置:
==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1
\\ /| Num examples = 353 | Num Epochs = 5 | Total steps = 115
O^O/ \_/ \ Batch size per device = 2 | Gradient accumulation steps = 8
\ / Data Parallel GPUs = 1 | Total batch size (2 x 8 x 1) = 16
"-____-" Trainable parameters = 20,185,088 of 616,235,008 (3.28% trained)
# 批注:训练配置详情
# - 总样本数:353个
# - 训练轮数:5轮
# - 总步数:115步
# - 有效批次大小:16(2×8×1)
# - 可训练参数:3.28%(20M/616M)
后面是微调细节:
{'loss': 2.1017, 'grad_norm': 3.180413007736206, 'learning_rate': 1.995283421166614e-05, 'epoch': 1.05}
{'loss': 2.2702, 'grad_norm': 2.8307971954345703, 'learning_rate': 1.9869167087338908e-05, 'epoch': 1.14}
{'loss': 2.0511, 'grad_norm': 3.1757583618164062, 'learning_rate': 1.9744105246469264e-05, 'epoch': 1.23}
{'loss': 2.1853, 'grad_norm': 2.5234906673431396, 'learning_rate': 1.957817324187987e-05, 'epoch': 1.32}
{'loss': 1.8145, 'grad_norm': 2.9424679279327393, 'learning_rate': 1.937206705006344e-05, 'epoch': 1.32}
8. 测试微调的模型
模型测试代码如下:
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from transformers import TextStreamer
import torch
max_seq_length = 512
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="ckpts/qwen2.5-1.5b-instruct-bnb-4bit-lora",
max_seq_length=max_seq_length,
dtype=None,
load_in_4bit=True,
)
# 启用unsloth的2x更快推理优化
print("启用推理优化...")
FastLanguageModel.for_inference(model)
print("模型加载完成")
# 定义训练数据格式化字符串模板
train_prompt_style = """你是一位精通中医理论的专家,特别擅长《伤寒论》的理论和实践应用。
请根据《伤寒论》的经典理论,准确回答以下问题。
### 问题:
{}
"""
question = '什么是太阳病?'
formatted_prompt = train_prompt_style.format(question)
print(f"格式化后的提示文本:\n-------------------\n{formatted_prompt}\n-------------------")
print(type(tokenizer)) # <class 'transformers.models.qwen2.tokenization_qwen2_fast.Qwen2TokenizerFast'>
inputs = tokenizer([formatted_prompt], return_tensors='pt', max_length=max_seq_length).to("cuda")
print(f"inputs: \n-------------------\n{inputs}\n-------------------")
# 生成回答(流式输出)
with torch.no_grad():
outputs = model.generate(
# 输入序列的token ID
inputs['input_ids'],
# 注意力掩码,指示哪些位置是有效输入
attention_mask=inputs['attention_mask'],
# 生成文本的最大长度限制
max_length=max_seq_length,
# 启用KV缓存,加速生成过程
use_cache=True,
# 温度参数,控制生成的随机性
# 值越低生成越确定,值越高生成越随机
temperature=0.8,
# 核采样参数,只从累积概率达到90%的词中选择
# 避免选择概率极低的词,提高生成质量
top_p=0.9,
# 启用采样模式,而不是贪婪解码
# 贪婪解码总是选择概率最高的词,容易产生重复
do_sample=True,
# 设置填充标记ID,用于处理变长序列
pad_token_id=tokenizer.eos_token_id,
# 设置结束标记ID,告诉模型何时停止生成
eos_token_id=tokenizer.eos_token_id,
# 重复惩罚参数,防止模型重复生成相同内容
# 1.1表示重复词的生成概率降低10%
repetition_penalty=1.1
)
print(type(outputs))
print(outputs)
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("\n" + "=" * 80)
print(answer)
print("\n" + "=" * 80)
print("测试完成!")
本文转载自AI取经路,作者:AI取经路
