对 ChatGLM-6B 做 LoRA fine tuning
ChatGLM-6B 是一个支持中英双语的对话语言模型,基于 GLM (General Language Model)。它只有 62 亿个参数,量化后最低 (INT4 量化) 只需要 6GB 的显存,完全可以部署到消费级显卡上。在实际使用这个模型一段时间以后,我们发现模型的对话表现能力确实非常不错。那么,基于这个模型做 Fine-tuning 就非常有价值了。
声明:
本文提供的所有技术信息,都基于 THUDM/chatglm-6b 的历史版本: 096f3de6b4959ce38bef7bb05f3129c931a3084e
。
源码地址:
搭建依赖环境
安装 PyTorch 环境: 1
pip install torch torchvision torchaudio
按照 ChatGLM-6B 的官方指导,安装软件依赖环境: 1
pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels
为了做 LoRA,还要安装 peft 1
pip install peft
加载模型和 Tokenizer
1 |
|
正如前面声明所述,本文使用的历史版本号是 096f3de6b4959ce38bef7bb05f3129c931a3084e
。如果开发者需要其他版本号,只需要更改 revision
值,并重新训练即可。
分析模型结构
模型加载完后,我们可以打印这个 model
和 tokenizer
,建立对模型的基本认知。
首先打印model
: 1
print(model)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22ChatGLMForConditionalGeneration(
(transformer): ChatGLMModel(
(word_embeddings): Embedding(150528, 4096)
(layers): ModuleList(
(0-27): 28 x GLMBlock(
(input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(attention): SelfAttention(
(rotary_emb): RotaryEmbedding()
(query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
(dense): Linear(in_features=4096, out_features=4096, bias=True)
)
(post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(mlp): GLU(
(dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
(dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
)
)
)
(final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=4096, out_features=150528, bias=False)
)
- 模型使用了 Transformer 结构,因此可以使用 LoRA 进行 Fine-tuning
- 从 Word Embedding 层可以看出,词汇表大小是
150528
- LoRA 可以操作的目标是:
query_key_value
再打印tokenizer
:
1 |
|
得到如下结果(为了便于阅读,已对结果做了分行处理): 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15ChatGLMTokenizer(
name_or_path='THUDM/chatglm-6b',
vocab_size=150344,
model_max_length=2048,
is_fast=False,
padding_side='left',
truncation_side='right',
special_tokens={
'bos_token': '<sop>',
'eos_token': '</s>',
'unk_token': '<unk>',
'pad_token': '<pad>',
'mask_token': '[MASK]'
}
)
这里有几个可以关注的点:
- 词汇表大小
vocab_size
是150344
- 不是一个 fast Tokenizer(
is_fast
的值是False
) - 特殊 token 包括:
bos
eos
pad
和mask
为什么 model 中的词汇表大小是 150528
,而 tokenizer
中定义的词汇表大小却是 150344
呢?读者可以带着这个疑问去读一读模型项目的源码,看看能不能找到答案。
配置 LoRA
借助 peft 库,我们可以很方便地对模型注入 LoRA。 1
2
3
4
5
6
7
8
9
10
11
12
13
14from peft import LoraConfig, get_peft_model, TaskType
def load_lora_config(model):
config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=8,
lora_alpha=32,
lora_dropout=0.1,
target_modules=["query_key_value"]
)
return get_peft_model(model, config)
model = load_lora_config(model)1
model.print_trainable_parameters()
1
trainable params: 3670016 || all params: 6258876416 || trainable%: 0.05863697820615348
可以看到,总的参数量是 6,258,876,416
,可训练的参数量是 3,670,016
,占比 0.0586%
左右。训练参数量只是百万级别的,可谓相当友好了!另外需要注意的一点是,ChatGLM-6B 是一个因果语言模型 (Causal Language Model),因此我们这里选择的任务类型是 CAUSAL_LM
。
构建数据集
定义常量
构建之前,我们先定义几个特殊 Token 常量: 1
2
3
4
5bos = tokenizer.bos_token_id
eop = tokenizer.eop_token_id
pad = tokenizer.pad_token_id
mask = tokenizer.mask_token_id
gmask = tokenizer.sp_tokenizer[tokenizer.gMASK_token]
将这几个值打印出来: 1
2
3
4
5print("bos = ", bos)
print("eop = ", eop)
print("pad = ", pad)
print("mask = ", mask)
print("gmask = ", gmask)1
2
3
4
5bos = 150004
eop = 150005
pad = 20003
mask = 150000
gmask = 150001
我们也可以直接用这个常量结果替换动态计算的部分。常量修改后的结果变成: 1
2
3
4
5bos = 150004
eop = 150005
pad = 20003
mask = 150000
gmask = 1500011
2
3device = "cuda"
max_src_length = 200
max_dst_length = 500
测试 Tokenizer 的编解码
我们可以先做个简单的测试: 1
2
3text = "AI探险家"
print(tokenizer.encode(text, add_special_tokens = True))
print(tokenizer.encode(text, add_special_tokens = False))1
2[26738, 98715, 83920, 150001, 150004]
[26738, 98715, 83920]
从这个结果可以看出,“AI探险家”这几个字的裸编码是 [26738, 98715, 83920]
。为什么是这样呢?我们可以对每一个数值再解码,看看输出结果: 1
2
3print(tokenizer.decode([26738]))
print(tokenizer.decode([98715]))
print(tokenizer.decode([83920]))1
2
3AI
探险
家
另外,当 add_special_tokens = True
时,编码结果会在末尾添加 150001
和 150004
,也就是 gmask
和 bos
。请注意,我们的训练数据,要按照如下编码要求进行构造: 1
[token, ..., token, gmask, bos, token, ... token, eop]
add_special_tokens = True
,后半部分文本的编码则让 add_special_tokens = False
,最后再拼接一个 eop
。
定义 Prompt
我们 Fine-tuning 的任务是问答任务(简称 QA),因此一个简单的 Prompt 是这样的:
1 |
|
{}
里填入 QA 训练集的问题文本。在显存有限的情况下,如果不对长文本做限制处理,很容易出现类似 CUDA out of memory
这样的报错。处理长文本,在给定编码后的数组上限时,可能存在这么几种方式:
- 截断末尾超出部分的编码
- 截断前面超出部分的编码
- 丢掉训练样本
每一种方式都有各自的优劣,开发者可以根据自身数据的特点自行选择一种处理方式。当然,如果你的显存够大,也可以不处理。本文以上述第一种方式进行处理。 为了不把 PROMPT_PATTERN
中的 \n答:
这几个字截断掉,我们将整个 PROMPT_PATTERN
拆成两部分:
1 |
|
基于这份 Prompt 模板,我们定义下面三个辅助方法: 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37def create_prompt(question):
return PROMPT_PATTERN.format(question), SEP_PATTERN
def create_prompt_ids(tokenizer, question, max_src_length):
prompt, sep = create_prompt(question)
sep_ids = tokenizer.encode(
sep,
add_special_tokens = True
)
sep_len = len(sep_ids)
special_tokens_num = 2
prompt_ids = tokenizer.encode(
prompt,
max_length = max_src_length - (sep_len - special_tokens_num),
truncation = True,
add_special_tokens = False
)
return prompt_ids + sep_ids
def create_inputs_and_labels(tokenizer, question, answer, device):
prompt = create_prompt_ids(tokenizer, question, max_src_length)
completion = tokenizer.encode(
answer,
max_length = max_dst_length,
truncation = True,
add_special_tokens = False
)
inputs = prompt + completion + [eop]
labels = [-100] * len(prompt) + completion + [eop]
inputs = torch.tensor(inputs, dtype=torch.long, device=device)
labels = torch.tensor(labels, dtype=torch.long, device=device)
return inputs, labels
- 从
create_prompt_ids
这个函数实现可以看出,我们编码分隔符SEP_PATTERN
时自动添加了前面所述的 2 个特殊 Token。 - 对
create_inputs_and_labels
的函数实现中,我们将labels
无需处理的部分用数值-100
来表示。因为ChatGLMForConditionalGeneration
内部在计算损失函数的时候,用的是torch.nn.CrossEntropyLoss
。该函数的参数之一ignore_index
默认值是-100
。这就让我们在计算损失函数时,无需考虑非标识部分的数值。
构建 Attention Mask 和 Position IDs
1 |
|
在这个通用实现中,我们针对 mask
和 gmask
两种情况做了区分,同时也对是否执行 position_encoding_2d
分情况处理。本文的 QA 任务采用的是 gmask
,并且使用 position_encoding_2d = True
。
我们可以构建下面的问答,来验证下这几个函数的输出: 1
2
3
4
5
6
7
8
9
10
11
12
13test_data = {
"question": "AI探险家帅不帅?",
"answer": "非常帅!"
}
inputs, labels = create_inputs_and_labels(tokenizer, **test_data, device=device)
attention_mask = get_attention_mask(tokenizer, inputs, device=device)
position_ids = get_position_ids(tokenizer, inputs, device=device)
print("inputs: \n", inputs.tolist())
print("\nlabels: \n", labels.tolist())
print("\nposition_ids: \n", position_ids.tolist())
print("\nattention_mask: \n", attention_mask.tolist())
输出结果(为了便于阅读,已对输出进行格式化操作): 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35inputs:
[20005, 84286, 20012, 31943, 98715, 83920, 87359, 83848, 87359, 20031, 20005, 20004, 87342, 20012, 150001, 150004, 20005, 84122, 87359, 20035, 150005]
labels:
[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 20005, 84122, 87359, 20035, 150005]
position_ids:
[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5]
]
attention_mask:
[[
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]]]
创建数据集
我们先定义具有如下格式的训练数据: 1
2
3
4
5train_data = [
{"question": "问题1", "answer": "答案1"},
{"question": "问题2", "answer": "答案2"},
]QADataset
类,如下: 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31from torch.utils.data import Dataset
class QADataset(Dataset):
def __init__(self, data, tokenizer) -> None:
super().__init__()
self.data = data
self.tokenizer = tokenizer
def __getitem__(self, index):
item_data = self.data[index]
tokenizer = self.tokenizer
input_ids, labels = create_inputs_and_labels(
tokenizer,
device=device,
**item_data
)
attention_mask = get_attention_mask(tokenizer, input_ids, device)
position_ids = get_position_ids(tokenizer, input_ids, device)
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids
}
def __len__(self):
return len(self.data)
然后创建一个 Data Collator:
1 |
|
开始训练
1 |
|
预测
1 |
|
保存训练模型
1 |
|
重载训练后的模型
1 |
|