-import string
+from functools import lru_cache
 from heapq import heappush, heappop
-from typing import List, Tuple
+from typing import List

 from labml import lab, monit
+from labml.utils.cache import cache_set
+from python_autocomplete.dataset import Tokenizer
+from python_autocomplete.dataset.break_words import SourceCodeTokenizer

-ID_CHARS = set(string.ascii_letters + string.digits + '_')

+class BPE(Tokenizer):
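+    # Combines a word-level tokenizer with a trained byte-pair encoder:
+    # text is split into words, and each word is encoded with learned merges.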
+    def __init__(self, bpe_en_de: 'BPEEnDe', word_tokenizer):
+        self.bpe = bpe_en_de
+        self.word_tokenizer = word_tokenizer
+        self.is_trained = True

-class BPE:
-    def __init__(self):
-        self.char_itos = []
-        self.char_stoi = {}
-        self.bpe_itos = []
-        self.bpe = []
-        self.common = {}
+    @property
+    def n_tokens(self):
+        return len(self.bpe.bpe)

-        self.bpe_itos = self.calc_bpe_itos()
+    @property
+    def itos(self):
+        return self.bpe.bpe_itos

-    def to_char_stoi(self, w: str):
-        return [self.char_stoi[c] for c in w]
+    @property
+    def stoi(self):
+        return self.bpe.bpe_stoi

-    def calc_bpe_itos(self):
-        itos = list(self.char_itos)
-        itos += [itos[p1] + itos[p2] for p1, p2 in self.bpe[len(self.char_itos):]]
-        return itos
+    def encode(self, data: str, *, is_silent: bool = True):
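+        # Split into words first; BPE merges never cross word boundaries.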
+        words = self.word_tokenizer.tokenize(data, is_silent=is_silent)

+        res = []
+        for w in monit.iterate('Encode words', words, is_silent=is_silent):
+            res += self.bpe.encode(w)

-class Tokenizer:
-    def collect_words(self, data: str):
-        raise NotImplementedError
+        return res

-    def get_words(self) -> Tuple[List[str], List[int]]:
-        raise NotImplementedError
+    def __call__(self, data: str):
+        encoded = self.encode(data)
+        return [self.itos[c] for c in encoded]

-    def tokenize(self, data: str) -> List[str]:
-        raise NotImplementedError
+    def rstrip(self, data: str):
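+        # Drop the last word, which may still be incomplete while typing,
+        # and return the remaining text together with its encoding.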
+        words = self.word_tokenizer.tokenize(data, is_silent=True)
+        words = words[:-1]
+        res = []
+        for w in words:
+            res += self.bpe.encode(w)

+        return ''.join(words), res


-class SourceCodeTokenizer(Tokenizer):
-    def __init__(self):
-        self.words = {}

-    def add_word(self, word):
-        if not word:
-            return
+class _BPEEncoder:
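+    # Applies learned merge rules to a single word. Positions form a doubly
+    # linked list (next_idx/prev_idx) so each merge is O(1); candidate pairs
+    # are kept in a heap ordered by merge priority.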
+    def __init__(self, pairs):
+        self.pairs = pairs
+        self.codes = []
+        self.next_idx = []
+        self.prev_idx = []
+        self.heap = []

-        if word not in self.words:
-            self.words[word] = 1
-        else:
-            self.words[word] += 1
+    def encode(self, codes: List[int]):
+        self.codes = codes
+        self.next_idx = BPELearner.default_next_pointers(len(codes))
+        self.prev_idx = BPELearner.default_prev_pointers(len(codes))
+        self.heap = []

-    def tokenize(self, data: str) -> List[str]:
-        last_idx = 0
-        is_id = False
-        res = []
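+        # Seed the heap with every adjacent pair that has a learned merge rule.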
+        for i in range(len(self.codes) - 1):
+            self.add_pair((self.codes[i], self.codes[i + 1]), i)

-        for i, c in monit.enum('Collect words', data):
-            if c in ID_CHARS:
-                if not is_id:
-                    if last_idx < i:
-                        res.append(data[last_idx:i])
-                    last_idx = i
-                    is_id = True
-            else:
-                if is_id:
-                    if last_idx < i:
-                        res.append(data[last_idx:i])
-                    last_idx = i
-                    is_id = False
-
-        if last_idx < len(data):
-            res.append(data[last_idx:])
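+        # Greedily apply merges, cheapest (earliest-learned) rule first.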
+        while self.heap:
+            _, idx, pair = heappop(self.heap)
+            self.merge(idx, pair)

-        return res
+        return [c for c in self.codes if c != -1]

-    def collect_words(self, data: str):
-        last_idx = 0
-        is_id = False
+    def merge(self, p2, pair):
+        p3 = self.next_idx[p2]
+
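+        # Heap entries can be stale: a neighbour may already have been merged.
+        # Re-check that this pair still exists at this position before merging.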
+        if p3 == -1 or pair[0] != self.codes[p2] or pair[1] != self.codes[p3]:
+            return

-        for i, c in monit.enum('Collect words', data):
-            if c in ID_CHARS:
-                if not is_id:
-                    self.add_word(data[last_idx:i])
-                    last_idx = i
-                    is_id = True
-            else:
-                if is_id:
-                    self.add_word(data[last_idx:i])
-                    last_idx = i
-                    is_id = False
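+        # Write the merged code over the left position, tombstone the right
+        # one with -1, then relink neighbours and push any newly formed pairs.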
+        self.codes[p2] = self.pairs[pair]
+        self.codes[p3] = -1
+        p1 = self.prev_idx[p2]
+        p4 = self.next_idx[p3]

-        self.add_word(data[last_idx:])
+        if p1 != -1:
+            self.add_pair((self.codes[p1], self.codes[p2]), p1)
+        self.next_idx[p2] = p4
+        if p4 != -1:
+            self.prev_idx[p4] = p2
+            self.add_pair((self.codes[p2], self.codes[p4]), p2)

-    def get_words(self):
-        words_list = [(f, w) for w, f in self.words.items()]
-        words_list.sort(key=lambda x: -x[0])
+    def add_pair(self, pair, idx):
+        if pair not in self.pairs:
+            return

-        return [w for _, w in words_list], [f for f, _ in words_list]
+        heappush(self.heap, (self.pairs[pair], idx, pair))


-class NoTokenizer(Tokenizer):
+class BPEEnDe:
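+    # Encoder/decoder state: character vocabulary, learned merge list, and
+    # the lookup tables derived from them.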
     def __init__(self):
-        self.data = ''
+        self.char_itos = []
+        self.char_stoi = {}
+        self.bpe = []
+        self.popular_words = {}
+
+        self.bpe_itos = []
+        self.bpe_stoi = {}
+        self.pairs = {}
+        self.encoder = None
+
+    def load(self, char_itos, char_stoi, bpe):
+        self.char_itos = char_itos
+        self.char_stoi = char_stoi
+        self.bpe = bpe
+
+        self.calc()
+
+    def set_popular_words(self, popular_words):
+        self.popular_words = popular_words
+
+    def calc(self):
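+        # Build the derived tables: token id -> string, string -> token id,
+        # and (left, right) pair -> merged token id for the encoder.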
+        self.bpe_itos = self.calc_bpe_itos()
+        self.bpe_stoi = {s: i for i, s in enumerate(self.bpe_itos)}
+        self.pairs = {(p[0], p[1]): c for c, p in enumerate(self.bpe) if not isinstance(p, int)}

-    def collect_words(self, data):
-        self.data += data
+        self.encoder = _BPEEncoder(self.pairs)

-    def get_words(self):
-        return [self.data], [1]
+    def to_char_stoi(self, w: str):
+        return [self.char_stoi[c] for c in w]

-    def tokenize(self, data: str) -> List[str]:
-        return [data]
+    def calc_bpe_itos(self):
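+        # The first len(char_itos) entries are single characters; every merge
+        # entry concatenates the strings of the two tokens it combines.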
+        itos = list(self.char_itos)
+        for p1, p2 in self.bpe[len(self.char_itos):]:
+            itos.append(itos[p1] + itos[p2])
+        return itos
+
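+    # Words repeat constantly in source code, so per-word encodings are cached.
+    # Characters without a vocabulary entry are silently skipped.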
+    @lru_cache(1024)
+    def encode(self, word: str):
+        if word in self.popular_words:
+            return self.popular_words[word]
+
+        return self.encoder.encode([self.char_stoi[c] for c in word if c in self.char_stoi])


 class BPELearner:
@@ -284,7 +314,7 @@ def main():
     path = lab.get_data_path() / 'train.py'

     with open(str(path), 'r') as f:
-        data = f.read()[:100_000]
+        data = f.read()

     tokenizer = SourceCodeTokenizer()
     tokenizer.collect_words(data)
@@ -295,6 +325,15 @@ def main():
     print(bpe.bpe_itos()[len(bpe.char_itos):])
     print(len(data), bpe.get_length())

+    cache_set('bpe', {
+        'char_itos': bpe.char_itos,
+        'char_stoi': bpe.char_stoi,
+        'bpe': bpe.bpe
+    })
+
+    bpe_en_de = BPEEnDe()
+    bpe_en_de.load(bpe.char_itos, bpe.char_stoi, bpe.bpe)
+

 if __name__ == '__main__':
     main()