PyTorch 随机 Mask 的技巧
本文记录一下用 PyTorch 随机 Mask 的技巧。
这里假设数值低于 2 的 token 都是特殊 token,不做处理。
1 |
|
其中 mask_arr = (rand < 0.5) * (tensor > 2)
只是一个示例,具体应根据实际情形来调整。
简单测试一下:
1 |
|
输出结果(因为是随机 mask,每个人的输出会有所不同):
1 |
|
凡是出现数值 4 的,都是被 mask 掉的。
下面是更一般的使用方式,注意最后一行用了 mlm
方法。
1 |
|
PyTorch 随机 Mask 的技巧
https://aizpy.com/2023/04/27/pytorch-random-mask/