对 ChatGLM-6B 做 LoRA fine tuning
ChatGLM-6B 是一个支持中英双语的对话语言模型,基于 GLM (General Language Model)。它只有 62 亿个参数,量化后最低 (INT4 量化) 只需要 6GB 的显存,完全可以部署到消费级显卡上。在实际使用这个模型一段时间以后,我们发现模型的对话表现能力确实非常不错。那么,基于这个模型做 Fine-tuning 就非常有价值了。
声明:
本文提供的所有技术信息,都基于 THUDM/chatglm-6b 的历史版本: 096f3de6b4959ce38bef7bb05f3129c931a3084e。
源码地址:
搭建依赖环境
安装 PyTorch 环境: 1
pip install torch torchvision torchaudio
按照 ChatGLM-6B 的官方指导,安装软件依赖环境: 1
pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels  
为了做 LoRA,还要安装 peft 1
pip install peft
加载模型和 Tokenizer
| 1 |  | 
正如前面声明所述,本文使用的历史版本号是 096f3de6b4959ce38bef7bb05f3129c931a3084e。如果开发者需要其他版本号,只需要更改 revision 值,并重新训练即可。
分析模型结构
模型加载完后,我们可以打印这个 model 和 tokenizer,建立对模型的基本认知。
首先打印model: 1
print(model)1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (word_embeddings): Embedding(150528, 4096)
    (layers): ModuleList(
      (0-27): 28 x GLMBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attention): SelfAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
          (dense): Linear(in_features=4096, out_features=4096, bias=True)
        )
        (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (mlp): GLU(
          (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
          (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
        )
      )
    )
    (final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=4096, out_features=150528, bias=False)
)
- 模型使用了 Transformer 结构,因此可以使用 LoRA 进行 Fine-tuning
- 从 Word Embedding 层可以看出,词汇表大小是 150528
- LoRA 可以操作的目标是:query_key_value
再打印tokenizer:
| 1 |  | 
得到如下结果(为了便于阅读,已对结果做了分行处理): 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15ChatGLMTokenizer(
	name_or_path='THUDM/chatglm-6b', 
	vocab_size=150344, 
	model_max_length=2048, 
	is_fast=False, 
	padding_side='left', 
	truncation_side='right', 
	special_tokens={
		'bos_token': '<sop>', 
		'eos_token': '</s>', 
		'unk_token': '<unk>', 
		'pad_token': '<pad>', 
		'mask_token': '[MASK]'
	}
)
这里有几个可以关注的点:
- 词汇表大小vocab_size是150344
- 不是一个 fast Tokenizer(is_fast的值是False)
- 特殊 token 包括:boseospad和mask
为什么 model 中的词汇表大小是 150528,而 tokenizer 中定义的词汇表大小却是 150344 呢?读者可以带着这个疑问去读一读模型项目的源码,看看能不能找到答案。
配置 LoRA
借助 peft 库,我们可以很方便地对模型注入 LoRA。 1
2
3
4
5
6
7
8
9
10
11
12
13
14from peft import LoraConfig, get_peft_model, TaskType
def load_lora_config(model):
	config = LoraConfig(
	    task_type=TaskType.CAUSAL_LM, 
	    inference_mode=False,
	    r=8, 
	    lora_alpha=32, 
	    lora_dropout=0.1,
	    target_modules=["query_key_value"]
	)
	return get_peft_model(model, config)
model = load_lora_config(model)1
model.print_trainable_parameters()1
trainable params: 3670016 || all params: 6258876416 || trainable%: 0.05863697820615348
可以看到,总的参数量是 6,258,876,416,可训练的参数量是 3,670,016,占比 0.0586% 左右。训练参数量只是百万级别的,可谓相当友好了!另外需要注意的一点是,ChatGLM-6B 是一个因果语言模型 (Causal Language Model),因此我们这里选择的任务类型是 CAUSAL_LM。
构建数据集
定义常量
构建之前,我们先定义几个特殊 Token 常量: 1
2
3
4
5bos = tokenizer.bos_token_id
eop = tokenizer.eop_token_id
pad = tokenizer.pad_token_id
mask = tokenizer.mask_token_id
gmask = tokenizer.sp_tokenizer[tokenizer.gMASK_token]
将这几个值打印出来: 1
2
3
4
5print("bos = ", bos)
print("eop = ", eop)
print("pad = ", pad)
print("mask = ", mask)
print("gmask = ", gmask)1
2
3
4
5bos =  150004
eop =  150005
pad =  20003
mask =  150000
gmask =  150001
我们也可以直接用这个常量结果替换动态计算的部分。常量修改后的结果变成: 1
2
3
4
5bos = 150004
eop = 150005
pad = 20003
mask = 150000
gmask = 1500011
2
3device = "cuda"
max_src_length = 200
max_dst_length = 500
测试 Tokenizer 的编解码
我们可以先做个简单的测试: 1
2
3text = "AI探险家"
print(tokenizer.encode(text, add_special_tokens = True))
print(tokenizer.encode(text, add_special_tokens = False))1
2[26738, 98715, 83920, 150001, 150004]
[26738, 98715, 83920]
从这个结果可以看出,“AI探险家”这几个字的裸编码是 [26738, 98715, 83920]。为什么是这样呢?我们可以对每一个数值再解码,看看输出结果: 1
2
3print(tokenizer.decode([26738]))
print(tokenizer.decode([98715]))
print(tokenizer.decode([83920]))1
2
3AI
探险
家
另外,当 add_special_tokens = True 时,编码结果会在末尾添加 150001和 150004,也就是 gmask 和 bos。请注意,我们的训练数据,要按照如下编码要求进行构造: 1
[token, ..., token, gmask, bos, token, ... token, eop]add_special_tokens = True,后半部分文本的编码则让 add_special_tokens = False,最后再拼接一个 eop。
定义 Prompt
我们 Fine-tuning 的任务是问答任务(简称 QA),因此一个简单的 Prompt 是这样的:
| 1 |  | 
{}里填入 QA 训练集的问题文本。在显存有限的情况下,如果不对长文本做限制处理,很容易出现类似 CUDA out of memory 这样的报错。处理长文本,在给定编码后的数组上限时,可能存在这么几种方式:
- 截断末尾超出部分的编码
- 截断前面超出部分的编码
- 丢掉训练样本
每一种方式都有各自的优劣,开发者可以根据自身数据的特点自行选择一种处理方式。当然,如果你的显存够大,也可以不处理。本文以上述第一种方式进行处理。 为了不把 PROMPT_PATTERN 中的 \n答: 这几个字截断掉,我们将整个 PROMPT_PATTERN 拆成两部分:
| 1 |  | 
基于这份 Prompt 模板,我们定义下面三个辅助方法: 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37def create_prompt(question):
    return PROMPT_PATTERN.format(question), SEP_PATTERN
def create_prompt_ids(tokenizer, question, max_src_length):
    prompt, sep = create_prompt(question)
    sep_ids = tokenizer.encode(
        sep, 
        add_special_tokens = True
    )
    sep_len = len(sep_ids)
    special_tokens_num = 2
    prompt_ids = tokenizer.encode(
        prompt, 
        max_length = max_src_length - (sep_len - special_tokens_num),
        truncation = True,
        add_special_tokens = False
    )
    return prompt_ids + sep_ids
def create_inputs_and_labels(tokenizer, question, answer, device):
    prompt = create_prompt_ids(tokenizer, question, max_src_length)
    completion = tokenizer.encode(
        answer, 
        max_length = max_dst_length,
        truncation = True,
        add_special_tokens = False
    )
    inputs = prompt + completion + [eop]
    labels = [-100] * len(prompt) + completion + [eop] 
    
    inputs = torch.tensor(inputs, dtype=torch.long, device=device)
    labels = torch.tensor(labels, dtype=torch.long, device=device)
    return inputs, labels
- 从 create_prompt_ids这个函数实现可以看出,我们编码分隔符SEP_PATTERN时自动添加了前面所述的 2 个特殊 Token。
- 对 create_inputs_and_labels的函数实现中,我们将labels无需处理的部分用数值-100来表示。因为ChatGLMForConditionalGeneration内部在计算损失函数的时候,用的是torch.nn.CrossEntropyLoss。该函数的参数之一ignore_index默认值是-100。这就让我们在计算损失函数时,无需考虑非标识部分的数值。
构建 Attention Mask 和 Position IDs
| 1 |  | 
在这个通用实现中,我们针对 mask 和 gmask 两种情况做了区分,同时也对是否执行 position_encoding_2d 分情况处理。本文的 QA 任务采用的是 gmask,并且使用 position_encoding_2d = True。
我们可以构建下面的问答,来验证下这几个函数的输出: 1
2
3
4
5
6
7
8
9
10
11
12
13test_data = {
	"question": "AI探险家帅不帅?",
	"answer": "非常帅!"
}
inputs, labels = create_inputs_and_labels(tokenizer, **test_data, device=device)
attention_mask = get_attention_mask(tokenizer, inputs, device=device)
position_ids = get_position_ids(tokenizer, inputs, device=device)
print("inputs: \n", inputs.tolist())
print("\nlabels: \n", labels.tolist())
print("\nposition_ids: \n", position_ids.tolist())
print("\nattention_mask: \n", attention_mask.tolist())
输出结果(为了便于阅读,已对输出进行格式化操作): 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35inputs: 
 [20005, 84286, 20012, 31943, 98715, 83920, 87359, 83848, 87359, 20031, 20005, 20004, 87342, 20012, 150001, 150004, 20005, 84122, 87359, 20035, 150005]
labels: 
 [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 20005, 84122, 87359, 20035, 150005]
position_ids: 
 [
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], 
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  1,  2,  3,  4,  5]
 ]
attention_mask: 
 [[
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True], 
 [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]]]
创建数据集
我们先定义具有如下格式的训练数据: 1
2
3
4
5train_data = [
	{"question": "问题1", "answer": "答案1"},
	{"question": "问题2", "answer": "答案2"},
]QADataset 类,如下: 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31from torch.utils.data import Dataset
class QADataset(Dataset):
    def __init__(self, data, tokenizer) -> None:
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
 
    def __getitem__(self, index):
        item_data = self.data[index]
        tokenizer = self.tokenizer
        input_ids, labels = create_inputs_and_labels(
            tokenizer, 
            device=device,
            **item_data
        )
        
        attention_mask = get_attention_mask(tokenizer, input_ids, device)
        position_ids = get_position_ids(tokenizer, input_ids, device)
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "position_ids": position_ids
        }
        
    def __len__(self):
        return len(self.data)
然后创建一个 Data Collator:
| 1 |  | 
开始训练
| 1 |  | 
预测
| 1 |  | 
保存训练模型
| 1 |  | 
重载训练后的模型
| 1 |  |