Skill蒸馏完整指南:可复制的实战案例
Skill蒸馏完整指南:可复制的实战案例
本文提供一个完整的Skill蒸馏案例,照着做就能成功。
—
案例目标
用GPT-4生成训练数据,蒸馏出一个”WordPress文章发布助手”。
—
环境准备
pip install openai transformers datasets torch export OPENAI_API_KEY="sk-xxxx" # 你的API Key
—
Step 1:大模型生成训练数据
创建文件 generate_data.py:
from openai import OpenAI
import json
client = OpenAI()
# 定义发布WordPress文章的Prompt模板
PROMPT_TEMPLATE = """你是一个WordPress运营专家。用户给你文章信息,你需要输出完整的发布流程。
文章信息:
- 标题:{title}
- 正文:{content}
- 分类:{category}
- 标签:{tags}
请按以下格式输出发布步骤(只输出步骤,不要解释):
步骤1: 连接XML-RPC接口
地址: https://helloai.jp/xmlrpc.php
用户名: ai-agent
步骤2: 上传封面图片
图片路径: {image_path}
步骤3: 创建文章
标题: {title}
正文: {content}
分类: {category}
步骤4: 发布文章"""
# 准备训练数据
articles = [
{
"title": "AI如何改变2026年的工作方式",
"content": "本文探讨人工智能对未来职场的影响...",
"category": "科技",
"tags": "AI,职场,未来",
"image_path": "./images/ai-work.jpg"
},
# 可以添加更多文章样本
]
# 生成训练数据
training_data = []
for article in articles:
response = client.chat.completions.create(
model="gpt-4",
messages=[
{"role": "system", "content": "你是一个WordPress运营专家。"},
{"role": "user", "content": PROMPT_TEMPLATE.format(**article)}
],
temperature=0.7
)
training_data.append({
"input": f"帮我发布文章:{article['title']}",
"output": response.choices[0].message.content
})
# 保存训练数据
with open("wp_skill_training_data.json", "w", encoding="utf-8") as f:
json.dump(training_data, f, ensure_ascii=False, indent=2)
print(f"生成了 {len(training_data)} 条训练数据")
运行:
python generate_data.py
—
Step 2:格式化数据为训练集
创建文件 prepare_dataset.py:
import json
# 读取原始数据
with open("wp_skill_training_data.json", "r") as f:
data = json.load(f)
# 转换为训练格式
train_data = []
for item in data:
# 构造成指令微调格式
train_data.append({
"instruction": item["input"],
"input": "",
"output": item["output"]
})
# 保存为JSONL(方便后续训练)
with open("train.jsonl", "w", encoding="utf-8") as f:
for item in train_data:
f.write(json.dumps(item, ensure_ascii=False) + "\n")
print(f"转换完成,共 {len(train_data)} 条")
运行:
python prepare_dataset.py
—
Step 3:微调小模型
创建文件 finetune.py:
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset
# 使用小型模型作为基础
model_name = "gpt2" # 可以换成 Qwen/Qwen2-0.5B 等更小的模型
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# 添加pad token
tokenizer.pad_token = tokenizer.eos_token
# 加载数据
dataset = load_dataset("json", data_files="train.jsonl", split="train")
# 数据预处理
def preprocess(example):
prompt = f"指令: {example['instruction']}\n输入: {example['input']}\n输出: {example['output']}"
result = tokenizer(prompt, truncation=True, max_length=512)
result["labels"] = result["input_ids"]
return result
dataset = dataset.map(preprocess, remove_columns=dataset.column_names)
# 训练参数
training_args = TrainingArguments(
output_dir="./wp-skill-model",
num_train_epochs=3,
per_device_train_batch_size=2,
learning_rate=1e-4,
save_steps=100,
logging_steps=50,
)
# 开始训练
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset,
)
trainer.train()
model.save_pretrained("./wp-skill-model")
tokenizer.save_pretrained("./wp-skill-model")
运行:
python finetune.py
—
Step 4:测试蒸馏后的模型
创建文件 test_model.py:
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载蒸馏后的模型
model_path = "./wp-skill-model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
# 测试输入
test_input = "帮我发布文章:GPT-5新特性解析"
# 生成
inputs = tokenizer(test_input, return_tensors="pt")
outputs = model.generate(
inputs["input_ids"],
max_new_tokens=200,
temperature=0.7,
do_sample=True
)
# 解析结果
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)
运行:
python test_model.py
预期输出示例:
步骤1: 连接XML-RPC接口 地址: https://helloai.jp/xmlrpc.php 用户名: ai-agent 步骤2: 上传封面图片 ... 步骤3: 创建文章 ...
—
Step 5:集成到实际应用
创建文件 wp_agent.py:
import xmlrpc.client
from transformers import AutoModelForCausalLM, AutoTokenizer
class WordPressPublisher:
def __init__(self, model_path="./wp-skill-model"):
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = AutoModelForCausalLM.from_pretrained(model_path)
# WordPress连接配置
self.wp_url = "https://helloai.jp/xmlrpc.php"
self.username = "ai-agent"
self.password = "你的应用密码"
def publish(self, title, content, category="未分类"):
# 1. 用蒸馏模型生成发布步骤
steps = self._generate_steps(title, content, category)
# 2. 执行步骤
for step in steps:
self._execute_step(step)
return "发布成功"
def _generate_steps(self, title, content, category):
prompt = f"帮我发布文章:{title}"
inputs = self.tokenizer(prompt, return_tensors="pt")
outputs = self.model.generate(inputs["input_ids"], max_new_tokens=200)
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
def _execute_step(self, step):
# 执行XML-RPC操作
server = xmlrpc.client.ServerProxy(self.wp_url)
# 根据step内容执行对应操作...
pass
# 使用
wp = WordPressPublisher()
wp.publish("测试文章", "这是正文内容")
—
关键参数建议
| 参数 | 小数据集(<100条) | 中数据集(100-1000条) | 大数据集(>1000条) |
|——|——————-|———————-|———————|
| 模型大小 | gpt2 / Qwen-0.5B | Qwen-1.8B | Qwen-7B |
| Epochs | 5-10 | 3-5 | 1-3 |
| Batch Size | 1-2 | 4-8 | 8-16 |
| Learning Rate | 1e-4 | 5e-5 | 2e-5 |
—
常见问题
Q: 训练数据不够?
A: 用大模型的few-shot能力生成更多样本。
Q: 效果不好?
A: 检查数据质量,确保input-output对应关系正确。
Q: 显存不够?
A: 使用4-bit量化或更小的模型。
照着这个流程,你就能完成一个完整的Skill蒸馏案例。
