Stanford CS336 Study Notes
Purpose of the Course
To make real innovations in the AI industry, you have to understand the underlying principles, not just tune parameters or call APIs.
The big AI labs are monopolizing and closing off state-of-the-art AI technology; learning the fundamentals is the only way to reproduce their results.
In this course, students hand-train a <1B-parameter model end to end: data scraping → cleaning → tokenizer → Transformer → training → evaluation. The hands-on work surfaces engineering details such as scheduling, parallelism, mixed precision, and memory savings, and pulls the network architectures we normally invoke through abstract APIs back into something we can inspect and modify ourselves.
Byte-Pair Encoding (BPE) Tokenizer
The Unicode Standard
The Unicode standard assigns a unique numeric code point to every character of every writing system in the world.
Problem (unicode1): Understanding Unicode (1 point)
(a) What Unicode character does chr(0) return?
'\x00'
(b) How does this character’s string representation (__repr__()) differ from its printed representation?
__repr__() returns a representation meant for debugging, while print() produces output meant for human readers. For example, chr(0) shows nothing visible when printed, but repr() displays it in escaped hex form as '\x00'.
(c) What happens when this character occurs in text? It may be helpful to play around with the following in your Python interpreter and see if it matches your expectations:
>>> chr(0)
>>> print(chr(0))
>>> "this is a test" + chr(0) + "string"
>>> print("this is a test" + chr(0) + "string")
chr(0) returns a non-printable character: it is kept inside the string, but print() shows nothing visible for it.
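For reference, the session above behaves roughly as follows (the NUL character renders as nothing visible in most terminals):
>>> chr(0)
'\x00'
>>> print(chr(0))

>>> "this is a test" + chr(0) + "string"
'this is a test\x00string'
>>> print("this is a test" + chr(0) + "string")
this is a teststring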
Unicode Encodings
Encodings such as UTF-8 map natural-language text to byte strings that computers can process.
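A quick interpreter comparison of the three encodings (the UTF-16 and UTF-32 outputs also include a byte-order mark):
>>> len("hello".encode("utf-8")), len("hello".encode("utf-16")), len("hello".encode("utf-32"))
(5, 12, 24)
>>> "你好".encode("utf-8")
b'\xe4\xbd\xa0\xe5\xa5\xbd'
>>> len("你好".encode("utf-8")), len("你好".encode("utf-16")), len("你好".encode("utf-32"))
(6, 6, 12)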
Problem (unicode2): Unicode Encodings (3 points)
(a) What are some reasons to prefer training our tokenizer on UTF-8 encoded bytes, rather than UTF-16 or UTF-32? It may be helpful to compare the output of these encodings for various input strings.
Most characters in typical training text are ASCII (English letters), which UTF-8 encodes in a single byte, so UTF-8 byte sequences are much shorter than their UTF-16 or UTF-32 counterparts.
(b) Consider the following (incorrect) function, which is intended to decode a UTF-8 byte string into a Unicode string. Why is this function incorrect? Provide an example of an input byte string that yields incorrect results.
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])

>>> decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))
'hello'
UTF-8 is a variable-length encoding, so it cannot be decoded one byte at a time: any character that occupies more than one byte breaks.
>>> decode_utf8_bytes_to_str_wrong("你好".encode("utf-8"))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 2, in decode_utf8_bytes_to_str_wrong
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe4 in position 0: unexpected end of data
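For contrast, decoding the whole byte string in a single call handles multi-byte characters correctly:
>>> "你好".encode("utf-8").decode("utf-8")
'你好'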
(c) Give a two byte sequence that does not decode to any Unicode character(s).
0xFF 0xFF (the byte 0xFF can never appear anywhere in valid UTF-8, so this sequence has no decoding).
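Checking this in the interpreter:
>>> bytes([0xFF, 0xFF]).decode("utf-8")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte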
Subword Tokenization
Byte-level tokenization completely eliminates the out-of-vocabulary (OOV) problem, but it significantly lengthens input sequences, which increases compute and stretches long-range dependencies, making models slower to train and harder to fit.
Subword tokenization is a middle ground between word-level and byte-level tokenization: it uses a larger vocabulary to compress the length of the byte sequence. Byte-level tokenization has a base vocabulary of only 256 byte values; subword tokenization merges frequently occurring byte sequences (such as the) into single tokens, shortening the sequence. BPE (Byte-Pair Encoding) builds these subword units by repeatedly merging the most frequent byte pair, so common words and fragments gradually become single tokens. A BPE-based subword tokenizer keeps the good OOV behaviour of byte-level tokenization while substantially shortening sequences; the process of building its vocabulary is called training the BPE tokenizer.
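To get a feel for the compression, here is a rough sketch using the off-the-shelf tiktoken library (not part of this assignment; GPT-2's tokenizer is a byte-level BPE) to compare byte length with BPE token count:
import tiktoken  # third-party; pip install tiktoken

text = "the quick brown fox jumps over the lazy dog"
enc = tiktoken.get_encoding("gpt2")   # GPT-2's byte-level BPE tokenizer
print(len(text.encode("utf-8")))      # 43 bytes
print(len(enc.encode(text)))          # roughly 9 BPE tokens, i.e. a large reduction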
BPE Tokenizer Training
Training a BPE tokenizer involves three steps:
Vocabulary initialization
The vocabulary is a mapping between tokens (byte strings) and integer IDs. Before training, it is trivially initialized with one entry per byte value, e.g. 'a' (0x61) -> 97, for 256 entries in total (plus any special tokens).
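A minimal sketch of this initialization, mirroring the run_train_bpe implementation below (putting special tokens first is just one convention; <|endoftext|> is only an example token):
special_tokens = ["<|endoftext|>"]          # example special token
vocab: dict[int, bytes] = {}
for tok in special_tokens:
    vocab[len(vocab)] = tok.encode("utf-8")  # special tokens take the first IDs
for i in range(256):
    vocab[len(vocab)] = bytes([i])           # then one entry per byte value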
Pre-tokenization
First split the text into pre-tokens (words and similar chunks) and count how many times each pre-token occurs; byte counts and adjacency are then computed within each pre-token.
A commonly used pre-tokenization regex:
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
- (?:...) is a non-capturing group: the parentheses group the alternatives without creating a numbered capture group, so it cannot be referenced later as $1 and the like.
- '(?:[sdmt]|ll|ve|re) matches an apostrophe followed by a specific ending, e.g. 'll (I’ll), 'm (I’m), 't (don’t): English contractions and possessives.
-  ?\p{L}+ matches an optional leading space followed by one or more Unicode letters (\p{L}): a space plus a word.
-  ?\p{N}+ is the same with Unicode digits (\p{N}): a space plus a run of digits.
-  ?[^\s\p{L}\p{N}]+ matches an optional space followed by characters that are none of whitespace, letters, or digits: punctuation, operators, emoji, and so on.
- \s+(?!\S) matches one or more whitespace characters not followed by a non-whitespace character, i.e. trailing whitespace; whitespace in the middle of text is followed by a letter or digit and is picked up by the earlier alternatives.
- \s+ is the final fallback: any remaining run of whitespace.
This regex essentially captures the surface structure of English text: words separated by spaces, forms like 'm marking possessives or contractions, and so on. The result of pre-tokenization is a table in which each entry has the form [pre-token, count], i.e. how many times each pre-token occurs, for example:
{"the": 20, "+": 2, "dog": 5, "'ll": 1}
Compute BPE merges
Using the pre-token table, count which two tokens in the vocabulary appear next to each other most often (in the first pass these are the single-byte tokens from the initial vocabulary).
For example, if the text contains many occurrences of the, the pairs t h and h e both appear frequently, so they get merged (one pair per step) into new tokens th and he and added to the vocabulary, which then contains t, e, th, and he at the same time. Repeating the operation, th will most likely be merged with e into the, so the vocabulary then contains t, e, th, he, and the.
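A toy sketch of the pair-counting part of one merge step, over a made-up pre-token table:
# Hypothetical pre-token counts; each pre-token is a tuple of single-byte tokens.
pre_tokens = {(b"t", b"h", b"e"): 20, (b"t", b"h", b"a", b"t"): 5}

pair_counts: dict[tuple[bytes, bytes], int] = {}
for pre_token, count in pre_tokens.items():
    for a, b in zip(pre_token, pre_token[1:]):
        pair_counts[(a, b)] = pair_counts.get((a, b), 0) + count

print(max(pair_counts, key=pair_counts.get))
# (b't', b'h') with count 25: this pair is merged first and b'th' joins the vocabulary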
Experimenting with BPE Tokenizer Training
Now train a BPE tokenizer on the TinyStories dataset.
Pre-tokenization is the most time-consuming step; the provided chunking helper can be used to split the file and process the chunks in parallel.
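A rough sketch of parallel pre-tokenization with multiprocessing (the chunking here is simplified and the helper names are placeholders, not the course-provided utilities; chunks should be cut at special-token boundaries so no pre-token straddles two chunks):
from collections import Counter
from multiprocessing import Pool

import regex as re

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

def count_pre_tokens(chunk: str) -> Counter:
    # Count pre-token occurrences inside one text chunk.
    return Counter(m.group() for m in re.finditer(PAT, chunk))

def parallel_pre_tokenize(chunks: list[str], workers: int = 8) -> Counter:
    # Pre-tokenize chunks in parallel and merge the per-chunk counts.
    # (Call this under `if __name__ == "__main__":` when the spawn start method is used.)
    with Pool(workers) as pool:
        counters = pool.map(count_pre_tokens, chunks)
    return sum(counters, Counter())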
Problem (train_bpe): BPE Tokenizer Training (15 points)
Run the tests with uv run pytest tests/test_train_bpe.py
import os

import regex as re  # the \p{...} classes in PAT require the third-party `regex` module


def run_train_bpe(
    input_path: str | os.PathLike,
    vocab_size: int,
    special_tokens: list[str],
    **kwargs,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Given the path to an input corpus, train a BPE tokenizer and
    output its vocabulary and merges.
    Args:
        input_path (str | os.PathLike): Path to BPE tokenizer training data.
        vocab_size (int): Total number of items in the tokenizer's vocabulary (including special tokens).
        special_tokens (list[str]): A list of string special tokens to be added to the tokenizer vocabulary.
            These strings will never be split into multiple tokens, and will always be
            kept as a single token. If these special tokens occur in the `input_path`,
            they are treated as any other string.
    Returns:
        tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
            vocab:
                The trained tokenizer vocabulary, a mapping from int (token ID in the vocabulary)
                to bytes (token bytes)
            merges:
                BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
                representing that <token1> was merged with <token2>.
                Merges are ordered by order of creation.
    """
    vocab: dict[int, bytes] = {}
    merges: list[tuple[bytes, bytes]] = []

    # Vocabulary initialization: special tokens first, then the 256 single-byte tokens.
    vocab_index = 0
    for special_token in special_tokens:
        vocab[vocab_index] = special_token.encode("utf-8")
        vocab_index += 1
    for i in range(256):
        vocab[vocab_index] = bytes([i])
        vocab_index += 1

    with open(input_path, "r", encoding="utf-8") as f:
        content = f.read()

    # Split the text on special tokens so that no merge ever crosses one of them.
    if special_tokens:
        special_tokens_pattern = "|".join(map(re.escape, special_tokens))
        parts = re.split(special_tokens_pattern, content)
    else:
        parts = [content]

    # Pre-tokenization: each pre-token is stored as a tuple of single-byte tokens
    # together with the number of times it occurs in the corpus.
    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    pre_tokens: dict[tuple[bytes, ...], int] = {}
    for part in parts:
        for match in re.finditer(PAT, part):
            pre_token = tuple(bytes([b]) for b in match.group().encode("utf-8"))
            pre_tokens[pre_token] = pre_tokens.get(pre_token, 0) + 1

    # Compute BPE merges: repeatedly merge the most frequent adjacent pair of tokens.
    # (Recounting every iteration keeps the code simple; incrementally updating the
    # pair counts after each merge is the standard optimization.)
    while vocab_index < vocab_size:
        # 1. Count every adjacent pair across all pre-tokens.
        token_pairs: dict[tuple[bytes, bytes], int] = {}
        for pre_token, count in pre_tokens.items():
            for i in range(len(pre_token) - 1):
                pair = (pre_token[i], pre_token[i + 1])
                token_pairs[pair] = token_pairs.get(pair, 0) + count
        if not token_pairs:
            break
        # 2. Pick the most frequent pair, breaking ties by the lexicographically greater pair.
        max_token_pair = max(token_pairs, key=lambda p: (token_pairs[p], p))
        merges.append(max_token_pair)
        vocab[vocab_index] = max_token_pair[0] + max_token_pair[1]
        vocab_index += 1
        # 3. Apply the merge inside every pre-token.
        new_pre_tokens: dict[tuple[bytes, ...], int] = {}
        for pre_token, count in pre_tokens.items():
            merged: list[bytes] = []
            i = 0
            while i < len(pre_token):
                if i + 1 < len(pre_token) and (pre_token[i], pre_token[i + 1]) == max_token_pair:
                    merged.append(pre_token[i] + pre_token[i + 1])
                    i += 2
                else:
                    merged.append(pre_token[i])
                    i += 1
            key = tuple(merged)
            new_pre_tokens[key] = new_pre_tokens.get(key, 0) + count
        pre_tokens = new_pre_tokens

    return (vocab, merges)
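A hedged usage sketch (the data path, vocabulary size, and special token are illustrative placeholders, not fixed by the function itself):
vocab, merges = run_train_bpe(
    input_path="data/TinyStories.txt",   # placeholder path to the corpus
    vocab_size=10_000,
    special_tokens=["<|endoftext|>"],
)
print(len(vocab))    # vocab_size entries: special tokens + 256 bytes + learned merges
print(merges[:5])    # the first few merges, in creation order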

