Google Gemma 2B 微调实战(IT科技新闻标题生成)

本文我将使用 Google 的 Gemma-2b 模型来微调一个基于IT科技新闻正文来生成对应标题的模型。并且我将介绍如何使用高度集成的训练框架来进行快速微调。

开始前

为了尽可能简化整个流程,我将使用 linux-cn 数据集1作为本次训练任务的训练数据。

模型选择使用 Gemma-2b2,在目前这个任务中 2b 级别的参数模型已经完全能满足当前的需求,当然你也可以尝试使用 7b 的模型。

我们在这里将直接使用 LLaMA-Factory3 训练框架来直接完成监督微调部分工作。当然该框架不仅支持监督微调(SFT)也支持预训练(PT)、奖励模型(RM)以及 PPO/DPO 的训练。

数据整理

linux-cn 数据集本身已经进行了数据的清洗和格式化,这一步我们只需要把我们需要的字段提取出后来后根据一定格式转换为 LLaMA-Factory 监督微调格式即可。

在本任务中,我们只需要数据集中的“title”和“content”两个字段即可。而 LLaMA-Factory 监督微调格式是如下格式的json文件。

[
  {
    "instruction": "What are the three primary colors?",
    "input": "",
    "output": "The three primary colors are red, blue, and yellow. These colors are called primary because they cannot be created by mixing other colors and all other colors can be made by combining them in various proportions. In the additive color system, used for light, the primary colors are red, green, and blue (RGB).",
  },
...
]

因为我们选择使用的是预训练模型,所以我们还需要指定一个 prompt template。指定 prompt template 的一个好处是你如果希望同时训练多个不同类型的任务,这样可以保证不同任务之间不会相互干扰。

完整代码如下:

import json

result = []

prompt_template = """Generate a title for the article:

{content}

---
Title:
"""
with open('archve.jsonl', 'r') as f:
    for line in f:
        p = json.loads(line)
        result.append({
            "instruction": prompt_template.replace("{content}", p['content']),
            "input": "",
            "output": p['title']
        })

with open('itnews_data.json', 'w') as f:
    json.dump(result, f,ensure_ascii=False, indent=4)

完成这一步后,我们就可以开始训练我们的模型了。但往往耗费时间最长以及最头疼的也是数据收集和数据整理这一部分。

模型微调

首先你需要保证 LLaMA-Factory 框架已经在你本地已经 ready 了。即你已经下载了该项目并且已经进行了项目的安装。

具体如何安装你可以查看该项目的 README,本文不再过多赘述。

首先我们需要将数据集移动到框架的 data 目录中,然后在 dataset_info.json 中添加我们自定义的数据集。

以下是本文实例所添加的数据集信息:

  "itnews": {
    "file_name": "itnews_data.json",
  },

当然不同类型的任务该框架会有不同的数据集格式要求,你可以参考项目中 dataset_info.json 的README 4

然后我们只需要执行如下命令就可以开始微调了,本文是在单张A100(80G)上进行的微调。

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train True \
    --model_name_or_path google/gemma-2b \
    --finetuning_type lora \
    --template default \
    --dataset itnews \
    --use_unsloth \
    --cutoff_len 8192 \
    --learning_rate 5e-05 \
    --num_train_epochs 10.0 \
    --max_samples 10000 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --warmup_steps 0 \
    --output_dir saves/Gemma-2B/lora/train_v1 \
    --bf16 True \
    --lora_rank 8 \
    --lora_dropout 0.1 \
    --lora_target q_proj,v_proj \
    --val_size 0.1 \
    --load_best_model_at_end True \
    --plot_loss True \
    --report_to "tensorboard"

在这里我需要对其中的几个参数进行简短的介绍:

--stage 即任务类型,在这里我们本文做的是监督微调所以是 sft,如果是其他任务你需要指定不同的类型。

--dataset 即数据集,这里的名称就是我们在 dataset_info.json 文件中指定的数据集名称。

--use_unsloth 这是一个训练加速器,官方宣称在 Gemma 7b 上拥有 2.4x 的加速,并且节省超一半的显存。在使用这个之前你需要按照官方文档进行安装5

--cutoff_len 文本令牌化后输入到模型的截止长度,因为本文使用的 Gemma 2b 模型,它的最大长度是 8192 ,所以在这里我设置的是 8192。但请记住更长的上下文也需要更多的 GPU 显存!

--max_samples 设置数据集加载的最大条数。本参数主要用作调试目的时非常好用,尤其是在你不确定 cutoff_lenbatch_size的时候,你可以加载很小的一部分数据进行测试,然后查看你显存的使用情况。

--learning_rate --num_train_epochs 学习率和训练周期,这是一个经验值,一般通过查看模型的 loss 来调整,当然在 LLM 模型训练中,本参数主要以模型是否符合任务需求而决定,也就是说完美的 loss 可能并不满足需求。

--per_device_train_batch_size --per_device_eval_batch_size --gradient_accumulation_steps 这三个参数需要根据你的显存大小以及是否使用多个GPU等条件进行不同的调整。

--output_dir 模型保存的目录。

更多的参数解释可以查看项目说明6,以及 transformers Trainer 的说明7

模型使用

在这里我们可以直接使用 transformers 来执行。

from transformers import AutoModelForCausalLM, AutoTokenizer,BitsAndBytesConfig

peft_model_id = "checkpoint-2000"
model = AutoModelForCausalLM.from_pretrained(peft_model_id,device_map="cuda")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

input_text = """
Generate a title for the article:

{content}

---
Title:
""" # 固定格式
encoding = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**encoding,max_length=8192,temperature=0.2,do_sample=True)
generated_ids = outputs[:, encoding.input_ids.shape[1]:]
generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(generated_texts[0])

我通过使用我自己的一篇差不多 5000 tokens 关于微服务的文章进行测试,并且这篇文章没有出现在数据集中8

在使用相同 prompt 的情况下的输出:

gemma-2b-it

> 微服务架构
    概述
    微服务架构的定义
    微服务架构的定义
    微服务架构的定义
    微服务架构的定义
    微服务架构的定义
    微服务架构的定义
    ...

lora

> 微服务架构的优势

通过简单的测试,不难发现模型在微调后,其返回格式上更加稳定,并且更加符合我们的要求。

总结

如果你不想训练,但又希望尝试本文中的模型,你可以在 huggingface 上搜索 gemma-2b-technology-news-title-generation-lora,找到从100-2200 steps 的所有 checkpoint 9

本文使用了一种相对简单的方式来训练符合自己需求的模型。在真实的企业场景中往往还涉及如何生成符合需求的数据集,集群训练,模型的AB测试,企业级部署等问题。我会在未来的文章中和大家分享。