文章

stanford cs336 学习笔记

stanford cs336 学习笔记

课程的目的

想要真正在 AI 行业做出创新,就必须了解其底层原理,而不是只是调参数、调 API。

AI 大厂正在垄断并封闭先进的 AI 技术,必须学习底层原理,才能复刻这些成果。

在课程中,学生将亲手训练 <1 B 参数模型:数据抓取 → 清洗 →Tokenizer→Transformer→ 训练 → 评估。通过实操体会调度、并行、混精、省显存等工程细节,把平常抽象调用的网络结构拉回到可自主控制修改的范围。

alt text

alt text

Byte-Pair Encoding (BPE) Tokenizer

assignment1

The Unicode Standard

unicode 标准给世界所有语言的每个字符指定了一个对应的数字编号

Problem (unicode1): Understanding Unicode (1 point)

  • (a) What Unicode character does chr(0) return?

    1
    
    '\x00'
    
  • (b) How does this character’s string representation (__repr__()) differ from its printed representation?

    1
    
    __repr__() 会返回一个便于程序调试的信息,而print返回一个便于人类可读的信息。比如 chr(0) 的内容就无法用 print 打印,但可以用 repr 实现按 hex 显示
    
  • (c) What happens when this character occurs in text? It may be helpful to play around with the following in your Python interpreter and see if it matches your expectations:

    1
    2
    3
    4
    
    >>> chr(0)
    >>> print(chr(0))
    >>> "this is a test" + chr(0) + "string"
    >>> print("this is a test" + chr(0) + "string")
    
    1
    
    chr(0)返回的是不可视的字符,所以无法用print()打印
    

Unicode Encodings

通过 utf-8 或其他编码规则,可以将自然语言编码成计算机能处理的字节串。

Problem (unicode2): Unicode Encodings (3 points)

  • (a) What are some reasons to prefer training our tokenizer on UTF-8 encoded bytes, rather than UTF-16 or UTF-32? It may be helpful to compare the output of these encodings for various input strings.

    1
    
    大部分的都是英文字母,utf-8 长度更短
    
  • (b) Consider the following (incorrect) function, which is intended to decode a UTF-8 byte string into a Unicode string. Why is this function incorrect? Provide an example of an input byte string that yields incorrect results.

    1
    2
    3
    4
    
    def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
      return "".join([bytes([b]).decode("utf-8") for b in bytestring])
    >>> decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))
      'hello'
    
    1
    2
    3
    4
    5
    6
    
    utf-8是变长的,不能拆分为单字节解码
    decode_utf8_bytes_to_str_wrong("你好".encode("utf-8"))
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "<stdin>", line 2, in decode_utf8_bytes_to_str_wrong
    UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe4 in position 0: unexpected end of data
    
  • (c) Give a two byte sequence that does not decode to any Unicode character(s).

    1
    
    0xFF 0xFF
    

Subword Tokenization

字节级(byte-level)分词(Tokenization)可以彻底解决词表外(OOV)问题,但会显著拉长输入序列长度,从而增加计算开销并加剧长程依赖,导致模型训练变慢、建模更困难。

子词(subword)分词是介于词级和字节级之间的折中方案:通过使用更大的词表来压缩字节序列长度。字节级分词的基础词表只有 256 个字节值,而子词分词会把高频出现的字节序列(如 the)合并成一个 token,从而减少序列长度。BPE(Byte-Pair Encoding)通过反复合并最常见的字节对来构建这些子词单元,使频繁出现的词或片段逐渐成为单个 token。基于 BPE 的子词分词器在保持良好 OOV 处理能力的同时,显著缩短序列长度;构建其词表的过程称为 BPE tokenizer 的训练。

BPE Tokenizer Training

训练 BPE 的三个步骤:

  • Vocabulary initialization(词表初始化)

    词表是从 token 到整型 id 的映射表。这个表在训练前的默认值非常简单,就是每个 byte(ascii) 到对应的数字,比如 'a'(0x61) -> 97,共 256 条。

  • Pre-tokenization(预标记)

    先提取出所有的单词(特殊组合),统计每个单词出现的次数,然后再匹配每个字母字节的个数及其邻近关系

    常用的预标记正则:

    1
    
    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    
    • (?:...) 问号和冒号放在括号里表示这个捕获组仅分组,而不产生分组编号,无法使用 $1 之类的方式访问
    • '(?:[sdmt]|ll|ve|re) 表示匹配所有 ' 和后面的特定词,比如 'll(I’ll)、'm(I’m)、't(don’t)等(所有格)
    • ?\p{L}+ 表示匹配 0 到 1 个空格,然后 \p{L}+ 表示按照 L 规则的 Unicode 类型匹配字符,可以匹配所有字母,结果就是一个空格加一个单词
    • ?\p{N}+ 同上,表示按照 N 规则匹配,可以匹配所有数字,结果就是一个空格加一串数字
    • ?[^\s\p{L}\p{N}]+ 表示匹配空格,然后不匹配所有的空格、字母、数字,其实就是匹配所有标点符号和表情符号、运算符等
    • \s+(?!\S) 表示匹配多个连续的空格,但后面不能有非空格(\S),所以只能匹配段落末尾的连续空格,中间的空格尾部都会跟字母或者数字。
    • \s+ 表示什么

    这个正则相当于是英语语句的基本语法,比如用空格分隔单词,用 'm 这种表示所有格(possessive)或缩写

    最后会得到一张预标记表:每一条的格式是[预标记,数量],也就是每个词出现的次数: 比如 {"the": 20, "+": 2, "dog": 5, "'ll": 1}

  • Compute BPE merges(计算 BPE 合并)

    根据预标记表,统计词表中的哪两个 token (第一次处理这些词都是初始化词表中单个字节的字符)能拼在一起(连续出现的概率较高)

    比如一个句子中有很多的 the,则t hh e两个组合出现概率高,就会拼在一起作为一个新的 token 加入词表。此时词表中同时有 t e th he

    重复这个操作,大概率 th 将会和 e 拼接为 the 加入词表。此时词表中同时有 t e th he the

Experimenting with BPE Tokenizer Training

现在用 TinyStories 数据集训练一个 BPE 模型

Pre-tokenization 比较耗时,可以用已有的分块并行化处理函数:

https://github.com/stanford-cs336/assignment1-basics/blob/main/cs336_basics/pretokenization_example.py

Problem (train_bpe): BPE Tokenizer Training (15 points)

使用 uv run pytest tests/test_train_bpe.py 运行测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
  def run_train_bpe(
      input_path: str | os.PathLike,
      vocab_size: int,
      special_tokens: list[str],
      **kwargs,
  ) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
      """Given the path to an input corpus, run train a BPE tokenizer and
      output its vocabulary and merges.

      Args:
          input_path (str | os.PathLike): Path to BPE tokenizer training data.
          vocab_size (int): Total number of items in the tokenizer's vocabulary (including special tokens).
          special_tokens (list[str]): A list of string special tokens to be added to the tokenizer vocabulary.
              These strings will never be split into multiple tokens, and will always be
              kept as a single token. If these special tokens occur in the `input_path`,
              they are treated as any other string.

      Returns:
          tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
              vocab:
                  The trained tokenizer vocabulary, a mapping from int (token ID in the vocabulary)
                  to bytes (token bytes)
              merges:
                  BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
                  representing that <token1> was merged with <token2>.
                  Merges are ordered by order of creation.
      """
      vocab = {}
      merges = []
      # vocabulary 初始化
      vocab_index = 0
      for special_token in special_tokens:
          vocab[vocab_index] = special_token.encode()
          vocab_index = vocab_index + 1
      for i in range(0,256):
          vocab[vocab_index] = bytes([i])
          vocab_index = vocab_index + 1
      
      with open(input_path, "r", encoding="utf-8") as f:
          content = f.read()
      # 去掉特殊字符(拆分文章)
      special_tokens_pattern = "|".join(map(re.escape, special_tokens))
      parts = re.split(special_tokens_pattern, content)
      # Pre-tokenization
      pre_tokens = {}
      PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
      for part in parts:
          matches = re.finditer(PAT, part)
          for val in matches:
              pre_token = val.group().encode()
              pre_tokens[pre_token] = pre_tokens.get(pre_token, 0) + 1
      # Compute BPE merges
      token_pairs = {}
      # 1. 提取所有单字节的pair
      for pre_token, count in pre_tokens.items():
          for i in range(len(pre_token)-1):
              token_pair = (pre_token[i].to_bytes(), pre_token[i+1].to_bytes())
              token_pairs[token_pair] = token_pairs.get(token_pair, 0) + count

      # 2. 根据已经有的pair更新计数
      while True:
          max_token_pair = tuple()
          max_token_count = 0
          for token_pair, count in token_pairs.items():
              if count > max_token_count:
                  exsited = False
                  for merge_pair in merges:
                      if merge_pair[0] + merge_pair[1] == token_pair[0] + token_pair[1]:
                          exsited = True
                          break
                  if exsited:
                      continue
                  max_token_count = count
                  max_token_pair = token_pair
          token_pairs[max_token_pair] = 0
          vocab[vocab_index] = max_token_pair[0] + max_token_pair[1]
          vocab_index = vocab_index + 1
          merges.append(max_token_pair)

          if vocab_index >= vocab_size:
              break
          for pre_token, count in pre_tokens.items():
              pos = 0
              while True:
                  max_token_bytes = max_token_pair[0] + max_token_pair[1]
                  pos = pre_token.find(max_token_bytes, pos)
                  if pos == -1:
                      break
                  if pos >= 1:
                      token_pair = (pre_token[pos - 1:pos], max_token_bytes)
                      token_pairs[token_pair] = token_pairs.get(token_pair, 0) + count
                  if pos + len(max_token_bytes) < len(pre_token):
                      token_pair = (max_token_bytes, pre_token[pos + len(max_token_bytes):pos + len(max_token_bytes)+1])
                      token_pairs[token_pair] = token_pairs.get(token_pair, 0) + count

                  for merge_pair in merges:
                      merge_pair_token = merge_pair[0] + merge_pair[1]
                      if pos >= len(merge_pair_token):
                          if pre_token[pos - len(merge_pair_token):pos] == merge_pair_token:
                              token_pair = (merge_pair_token, max_token_bytes)
                              token_pairs[token_pair] = token_pairs.get(token_pair, 0) + count
                      if pos < len(pre_token) - len(merge_pair_token):
                          if pre_token[pos + len(max_token_bytes):pos + len(max_token_bytes)+len(merge_pair_token)] == merge_pair_token:
                              token_pair = (max_token_bytes, merge_pair_token)
                              token_pairs[token_pair] = token_pairs.get(token_pair, 0) + count

                  pos = pos + len(max_token_bytes)

      return (vocab, merges)

参考

本文由作者按照 CC BY-SA 4.0 进行授权