diff --git a/crates/bpe-openai/src/lib.rs b/crates/bpe-openai/src/lib.rs index 51f514f8..7a0aad48 100644 --- a/crates/bpe-openai/src/lib.rs +++ b/crates/bpe-openai/src/lib.rs @@ -4,7 +4,6 @@ use bpe::byte_pair_encoding::BytePairEncoding; use either::Either; use regex_automata::{ meta::{BuildError, Regex}, - util::captures::Captures, Anchored, Input, }; @@ -177,7 +176,6 @@ impl Pretokenizer { lookahead: &self.lookahead, text, last: 0, - caps: Captures::matches(self.pat.group_info().clone()), } } } @@ -195,7 +193,6 @@ struct Splits<'a> { lookahead: &'a [bool], text: &'a str, last: usize, - caps: Captures, } impl<'a> Iterator for Splits<'a> { @@ -203,9 +200,7 @@ impl<'a> Iterator for Splits<'a> { fn next(&mut self) -> Option { let input = Input::new(&self.text[self.last..]).anchored(Anchored::Yes); - self.caps.clear(); - self.pat.captures(input, &mut self.caps); - let m = self.caps.get_match()?; + let m = self.pat.find(input)?; let start = self.last; let mut end = self.last + m.range().end; if self.lookahead[m.pattern().as_usize()] { diff --git a/crates/bpe/benchmarks/performance.rs b/crates/bpe/benchmarks/performance.rs index 9b2a257a..f52941ca 100644 --- a/crates/bpe/benchmarks/performance.rs +++ b/crates/bpe/benchmarks/performance.rs @@ -244,12 +244,32 @@ fn worstcase_comparison_benchmark(c: &mut Criterion) { } } +fn pretokenization_benchmark(c: &mut Criterion) { + for (name, tok, _, _) in TOKENIZERS.iter() { + let text = create_test_string(&tok.bpe, 80_000); + + let mut group = c.benchmark_group(format!("pretokenization-{name}")); + group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + for bytes in [10, 100, 1000, 10000] { + group.throughput(criterion::Throughput::Bytes(bytes as u64)); + group.bench_with_input(BenchmarkId::new("split", bytes), &bytes, |b, bytes| { + b.iter_batched( + || select_test_string(&text, *bytes), + |text| tok.split(text).count(), + criterion::BatchSize::SmallInput, + ) + }); + } + group.finish(); + } +} + criterion_group!( name = benches; config = Criterion::default() .warm_up_time(Duration::from_millis(500)) .measurement_time(Duration::from_millis(4000)) .nresamples(1000); - targets = counting_benchmark, encoding_benchmark, appending_benchmark, comparison_benchmark, worstcase_comparison_benchmark + targets = counting_benchmark, encoding_benchmark, appending_benchmark, pretokenization_benchmark, comparison_benchmark, worstcase_comparison_benchmark ); criterion_main!(benches);