PyTorch 随机 Mask 的技巧

本文记录一下用 PyTorch 随机 Mask 的技巧。

这里假设数值低于 2 的 token 都是特殊 token,不做处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch

# 定义 mask 的 token id
mask_token_id = 4

def mlm(tensor):
# 克隆一份数据,避免修改原始输入数据
tensor = tensor.detach().clone()

rand = torch.rand(tensor.shape)
# 50% 的概率随机 mask
# 忽略掉数值低于 2 的特殊 token
mask_arr = (rand < 0.5) * (tensor > 2)
for i in range(tensor.shape[0]):
selection = torch.flatten(mask_arr[i].nonzero()).tolist()
tensor[i, selection] = mask_token_id

return tensor

其中 mask_arr = (rand < 0.5) * (tensor > 2) 只是一个示例,具体应根据实际情形来调整。

简单测试一下:

1
2
3
4
5
6
7
8
9
10
samples = torch.tensor([
[0, 1652, 233, 3252, 1234, 634, 1, 1, 1, 1],
[0, 223, 1530, 232, 4134, 832, 20, 1, 1, 1],
])

labels = samples
input_ids = mlm(samples)

print("labels: \n", labels)
print("input_ids: \n", input_ids)

输出结果(因为是随机 mask,每个人的输出会有所不同):

1
2
3
4
5
6
7
8
9
10
11
labels:
tensor([
[0, 1652, 233, 3252, 1234, 634, 1, 1, 1, 1],
[0, 223, 1530, 232, 4134, 832, 20, 1, 1, 1]
])

input_ids:
tensor([
[0, 4, 4, 4, 1234, 634, 1, 1, 1, 1],
[0, 4, 4, 232, 4, 832, 4, 1, 1, 1]
])

凡是出现数值 4 的,都是被 mask 掉的。

下面是更一般的使用方式,注意最后一行用了 mlm 方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from tqdm.auto import tqdm

paths = ["sample_0.txt", "sample_1.txt", "sample_2.txt", "sample_3.txt", "sample_4.txt"]

input_ids = []
mask = []
labels = []

for path in tqdm(paths):
with open(path, 'r', encoding='utf-8') as f:
lines = f.read().split('\n')
sample = tokenizer(lines, max_length=512, padding='max_length', trunction=True)
labels.append(sample.input_ids)
mask.append(sample.attention_mask)
input_ids.append(mlm(sample.input_ids))

PyTorch 随机 Mask 的技巧
https://aizpy.com/2023/04/27/pytorch-random-mask/
作者
aizpy
发布于
2023年4月27日
许可协议