11import string
22from heapq import heappush , heappop
3+ from typing import List , Tuple
34
45from labml import lab , monit
56
@@ -25,33 +26,20 @@ def calc_bpe_itos(self):
2526 return itos
2627
2728
28- class BPELearner :
29- def __init__ (self , data : str ):
30- self .data = data
31- self .words = {}
32- self .heap = []
33- self .heap_modified = set ()
34- self .char_itos = []
35- self .char_stoi = {}
36- self .bpe = []
37- self .word_codes = []
38- self .word_code_prev = {}
39- self .word_code_next = {}
class Tokenizer:
    """Interface for splitting raw text into word tokens.

    Concrete tokenizers accumulate word statistics with ``collect_words``
    and expose them through ``get_words``; ``tokenize`` splits a string
    without touching any accumulated state.
    """

    def collect_words(self, data: str):
        """Accumulate word statistics from ``data``."""
        raise NotImplementedError

    def get_words(self) -> Tuple[List[str], List[int]]:
        """Return collected words and their frequencies, aligned by index."""
        raise NotImplementedError

    def tokenize(self, data: str) -> List[str]:
        """Split ``data`` into a list of tokens."""
        raise NotImplementedError
4838
49- def learn (self , merges : int ):
50- for i in monit .iterate ('BPE' , merges ):
51- while True :
52- res = self .merge_pair ()
53- if res is not None :
54- break
39+
40+ class SourceCodeTokenizer (Tokenizer ):
41+ def __init__ (self ):
42+ self .words = {}
5543
5644 def add_word (self , word ):
5745 if not word :
@@ -62,28 +50,96 @@ def add_word(self, word):
6250 else :
6351 self .words [word ] += 1
6452
65- def collect_words (self ) :
53+ def tokenize (self , data : str ) -> List [ str ] :
6654 last_idx = 0
6755 is_id = False
56+ res = []
6857
69- for i , c in monit .enum ('Collect words' , self . data ):
58+ for i , c in monit .enum ('Collect words' , data ):
7059 if c in ID_CHARS :
7160 if not is_id :
72- self .add_word (self .data [last_idx :i ])
61+ if last_idx < i :
62+ res .append (data [last_idx :i ])
7363 last_idx = i
7464 is_id = True
7565 else :
7666 if is_id :
77- self .add_word (self .data [last_idx :i ])
67+ if last_idx < i :
68+ res .append (data [last_idx :i ])
7869 last_idx = i
7970 is_id = False
8071
81- self .add_word (self .data [last_idx :])
72+ if last_idx < len (data ):
73+ res .append (data [last_idx :])
74+
75+ return res
76+
77+ def collect_words (self , data : str ):
78+ last_idx = 0
79+ is_id = False
80+
81+ for i , c in monit .enum ('Collect words' , data ):
82+ if c in ID_CHARS :
83+ if not is_id :
84+ self .add_word (data [last_idx :i ])
85+ last_idx = i
86+ is_id = True
87+ else :
88+ if is_id :
89+ self .add_word (data [last_idx :i ])
90+ last_idx = i
91+ is_id = False
92+
93+ self .add_word (data [last_idx :])
94+
95+ def get_words (self ):
8296 words_list = [(f , w ) for w , f in self .words .items ()]
8397 words_list .sort (key = lambda x : - x [0 ])
8498
85- self .words_list = [w for _ , w in words_list ]
86- self .word_freq = [f for f , _ in words_list ]
99+ return [w for _ , w in words_list ], [f for f , _ in words_list ]
100+
101+
class NoTokenizer(Tokenizer):
    """Trivial tokenizer: the entire text stream is a single token."""

    def __init__(self):
        # Raw text accumulated across collect_words() calls.
        self.data = ''

    def collect_words(self, data):
        """Append ``data`` to the accumulated text."""
        self.data = self.data + data

    def get_words(self):
        """Return all text seen so far as one word with frequency 1."""
        return [self.data], [1]

    def tokenize(self, data: str) -> List[str]:
        """The whole input is one token."""
        return [data]
114+
115+
116+ class BPELearner :
    def __init__(self, words_list: List[str], word_freq: List[int]):
        """Prepare BPE learner state from a pre-tokenized vocabulary.

        :param words_list: unique words (as produced by ``Tokenizer.get_words``)
        :param word_freq: frequency of each word, aligned with ``words_list``
        """
        self.words_list = words_list
        self.word_freq = word_freq

        # Heap of candidate pairs, plus the set of pairs whose heap entries
        # have gone stale (ordering/maintenance handled in collect_pairs /
        # merge_pair, which are defined elsewhere in this file).
        self.heap = []
        self.heap_modified = set()
        # Character-level vocabulary: index <-> string lookups.
        self.char_itos = []
        self.char_stoi = {}
        # Learned BPE merges.
        self.bpe = []
        # Per-word code sequences and prev/next linkage.
        # NOTE(review): word_code_prev/next were dicts in the pre-refactor
        # version and are lists here — presumably indexed per word; confirm
        # against build_word_arrays().
        self.word_codes = []
        self.word_code_prev = []
        self.word_code_next = []

        # Pair statistics: occurrence counts and occurrence locations.
        self.counts = {}
        self.locations = {}

        # Build initial state; order matters: vocab before word arrays,
        # word arrays before pair collection.
        self.build_vocab()
        self.build_word_arrays()
        self.collect_pairs()
136+
    def learn(self, merges: int):
        """Run ``merges`` BPE merge steps.

        Each step calls ``merge_pair`` repeatedly until it returns a
        non-None result — presumably retrying past stale heap entries;
        confirm against ``merge_pair`` (defined elsewhere in this file).

        :param merges: number of merge operations to perform
        """
        for i in monit.iterate('BPE', merges):
            while True:
                res = self.merge_pair()
                if res is not None:
                    break
87143
88144 def build_vocab (self ):
89145 vocab = set ()
@@ -230,11 +286,14 @@ def main():
230286 with open (str (path ), 'r' ) as f :
231287 data = f .read ()[:100_000 ]
232288
233- bpe = BPELearner (data )
289+ tokenizer = SourceCodeTokenizer ()
290+ tokenizer .collect_words (data )
291+
292+ bpe = BPELearner (* tokenizer .get_words ())
234293 bpe .learn (1000 )
235294 print (len (bpe .bpe ))
236295 print (bpe .bpe_itos ()[len (bpe .char_itos ):])
237- print (len (bpe . data ), bpe .get_length ())
296+ print (len (data ), bpe .get_length ())
238297
239298
240299if __name__ == '__main__' :
0 commit comments