use std::cmp::Ordering;
use std::collections::HashMap as StdHashMap;

use dary_heap::OctonaryHeap;
use fancy_regex::Regex;
use pyo3::prelude::*;
use ahash::{AHashMap, AHashSet};
use compact_str::CompactString;
use rayon::prelude::*;

// Default GPT-4 style regex pattern for splitting text
const GPT4_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";

type Pair = (u32, u32);

/// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation
#[pyclass]
pub struct Tokenizer {
    /// Maps pairs of token IDs to their merged token ID
    pub merges: StdHashMap<Pair, u32>,
    /// The regex pattern used for text splitting
    pub pattern: String,
    /// Compiled regex for efficiency
    compiled_pattern: Regex,
}

// ------------------------ internal helpers ------------------------

#[derive(Clone, Debug)]
struct Word {
    ids: Vec<u32>,
}

impl Word {
    #[inline]
    fn new(ids: Vec<u32>) -> Self {
        Self { ids }
    }

    #[inline]
    fn pairs<'a>(&'a self) -> impl Iterator<Item = Pair> + 'a {
        self.ids.windows(2).map(|w| (w[0], w[1]))
    }

    /// Merge all non-overlapping occurrences of pair -> new_id.
    /// Returns a small Vec of local pair-count deltas for THIS word only:
    /// -1 for removed pairs, +1 for newly created pairs.
    ///
    /// NOTE: this version deliberately avoids a HashMap in the hot loop.
    fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> {
        let (a, b) = pair;
        let n = self.ids.len();
        if n < 2 {
            return Vec::new();
        }
        let mut out: Vec<u32> = Vec::with_capacity(n);
        let mut deltas: Vec<(Pair, i32)> = Vec::with_capacity(6);

        let mut i = 0;
        while i < n {
            if i + 1 < n && self.ids[i] == a && self.ids[i + 1] == b {
                let left = out.last().copied();
                let right = if i + 2 < n { Some(self.ids[i + 2]) } else { None };

                // remove old pairs
                if let Some(x) = left {
                    deltas.push(((x, a), -1));
                    deltas.push(((x, new_id), 1));
                }
                deltas.push(((a, b), -1));
                if let Some(y) = right {
                    deltas.push(((b, y), -1));
                    deltas.push(((new_id, y), 1));
                }

                // write merged token
                out.push(new_id);
                i += 2; // skip 'a' and 'b'
            } else {
                out.push(self.ids[i]);
                i += 1;
            }
        }

        self.ids = out;
        deltas
    }
}

#[derive(Debug, Eq)]
struct MergeJob {
    pair: Pair,
    count: u64,
    /// set of word indices where this pair may occur and needs processing
    pos: AHashSet<usize>,
}

impl PartialEq for MergeJob {
    fn eq(&self, other: &Self) -> bool {
        self.count == other.count && self.pair == other.pair
    }
}

impl PartialOrd for MergeJob {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MergeJob {
    fn cmp(&self, other: &Self) -> Ordering {
        // Max-heap by count; tie-break to ascending pair order (deterministic)
        if self.count != other.count {
            self.count.cmp(&other.count)
        } else {
            // ascending order on the pair when counts tie
            other.pair.cmp(&self.pair)
        }
    }
}

#[inline]
fn count_pairs_parallel(
    words: &[Word],
    counts: &[i32],
) -> (AHashMap<Pair, i32>, AHashMap<Pair, AHashSet<usize>>) {
    words
        .par_iter()
        .enumerate()
        .map(|(i, w)| {
            let mut local_pc: AHashMap<Pair, i32> = AHashMap::new();
            let mut local_wtu: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
            if w.ids.len() >= 2 && counts[i] != 0 {
                for (a, b) in w.pairs() {
                    *local_pc.entry((a, b)).or_default() += counts[i];
                    local_wtu.entry((a, b)).or_default().insert(i);
                }
            }
            (local_pc, local_wtu)
        })
        .reduce(
            || (AHashMap::new(), AHashMap::new()),
            |(mut acc_pc, mut acc_wtu), (pc, wtu)| {
                for (k, v) in pc {
                    *acc_pc.entry(k).or_default() += v;
                }
                for (k, s) in wtu {
                    acc_wtu.entry(k).or_default().extend(s);
                }
                (acc_pc, acc_wtu)
            },
        )
}

// ------------------------ END helpers ------------------------
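
// A minimal sketch (not part of the original crate) illustrating the delta
// bookkeeping that `Word::merge_pair` returns: merging (1, 2) -> 9 inside
// [0, 1, 2, 3] removes the pairs (0,1), (1,2), (2,3) and creates (0,9), (9,3).
// The module and test names below are ours, added for illustration only.
#[cfg(test)]
mod word_merge_sketch {
    use super::*;

    #[test]
    fn merge_pair_reports_local_deltas() {
        let mut w = Word::new(vec![0, 1, 2, 3]);
        let deltas = w.merge_pair((1, 2), 9);
        assert_eq!(w.ids, vec![0, 9, 3]);
        // Removed pairs get -1, newly created pairs get +1.
        assert!(deltas.contains(&((0, 1), -1)));
        assert!(deltas.contains(&((0, 9), 1)));
        assert!(deltas.contains(&((1, 2), -1)));
        assert!(deltas.contains(&((2, 3), -1)));
        assert!(deltas.contains(&((9, 3), 1)));
    }
}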
impl Tokenizer {
    /// Core incremental BPE training given unique words and their counts.
    /// `words`: one entry per unique chunk (Vec of token-ids/bytes).
    /// `counts`: same length as `words`, count per chunk.
    fn train_core_incremental(&mut self, mut words: Vec<Word>, counts: Vec<i32>, vocab_size: u32) {
        assert!(vocab_size >= 256, "vocab_size must be at least 256");
        let num_merges = vocab_size - 256;
        log::info!("Starting BPE training: {} merges to compute", num_merges);
        self.merges.clear();

        // ---- Initial pair_counts and where_to_update (parallel) ----
        log::info!("Computing initial pair counts from {} unique sequences", words.len());
        let (mut pair_counts, mut where_to_update) = count_pairs_parallel(&words, &counts);

        // ---- Build heap ----
        log::info!("Building heap with {} unique pairs", pair_counts.len());
        let mut heap = OctonaryHeap::with_capacity(pair_counts.len());
        for (pair, pos) in where_to_update.drain() {
            let c = *pair_counts.get(&pair).unwrap_or(&0);
            if c > 0 {
                heap.push(MergeJob {
                    pair,
                    count: c as u64,
                    pos,
                });
            }
        }

        // ---- Merge loop ----
        log::info!("Starting merge loop");
        let mut merges_done = 0u32;
        let mut last_log_percent = 0u32;
        while merges_done < num_merges {
            let Some(mut top) = heap.pop() else {
                break;
            };

            // Lazy refresh
            let current = *pair_counts.get(&top.pair).unwrap_or(&0);
            if top.count != current as u64 {
                top.count = current as u64;
                if top.count > 0 {
                    heap.push(top);
                }
                continue;
            }
            if top.count == 0 {
                break;
            }

            // Record merge
            let new_id = 256 + merges_done;
            self.merges.insert(top.pair, new_id);

            // Merge this pair in all words where it occurs
            let mut local_pos_updates: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
            for &word_idx in &top.pos {
                // Apply merge to this word and collect pair-count deltas
                let changes = words[word_idx].merge_pair(top.pair, new_id);
                // Update global pair counts based on this word's count
                for (pair, delta) in changes {
                    let delta_total = delta * counts[word_idx];
                    if delta_total != 0 {
                        *pair_counts.entry(pair).or_default() += delta_total;
                        if delta > 0 {
                            local_pos_updates.entry(pair).or_default().insert(word_idx);
                        }
                    }
                }
            }

            // Add the updated pair counts back to the heap
            for (pair, pos) in local_pos_updates {
                let cnt = *pair_counts.get(&pair).unwrap_or(&0);
                if cnt > 0 {
                    heap.push(MergeJob {
                        pair,
                        count: cnt as u64,
                        pos,
                    });
                }
            }

            merges_done += 1;

            // Log progress every 1%
            let current_percent = (merges_done * 100) / num_merges;
            if current_percent > last_log_percent {
                log::info!(
                    "Progress: {}% ({}/{} merges) - Last merge: {:?} -> {} (frequency: {})",
                    current_percent,
                    merges_done,
                    num_merges,
                    top.pair,
                    new_id,
                    top.count
                );
                last_log_percent = current_percent;
            }
        }

        log::info!("Finished training: {} merges completed", merges_done);
    }
}
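
// A minimal sketch (not part of the original crate) that drives the training
// core directly from Rust. With the single chunk "abab" seen twice, the most
// frequent pair is (97, 98) = b"ab", so the one merge requested by
// vocab_size = 257 should map (97, 98) -> 256. Module and test names are ours.
#[cfg(test)]
mod train_core_sketch {
    use super::*;

    #[test]
    fn single_merge_on_tiny_corpus() {
        let mut tok = Tokenizer::new();
        let words = vec![Word::new("abab".bytes().map(|b| b as u32).collect())];
        let counts = vec![2];
        tok.train_core_incremental(words, counts, 257);
        assert_eq!(tok.merges.len(), 1);
        assert_eq!(tok.merges.get(&(97, 98)), Some(&256));
    }
}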
/// Public methods for the Tokenizer class that will be exposed to Python.
#[pymethods]
impl Tokenizer {
    /// Create a new Tokenizer
    #[new]
    pub fn new() -> Self {
        Self {
            merges: StdHashMap::new(),
            pattern: String::new(),
            compiled_pattern: Regex::new("").expect("Empty regex should be valid"),
        }
    }

    /// Train from a streaming iterator (parallel ingestion).
    /// We refill a Rust Vec buffer under the GIL, then release the GIL
    /// to do the heavy splitting and counting **in parallel** with rayon.
    #[pyo3(signature = (iterator, vocab_size, buffer_size=8192, pattern=None))]
    #[pyo3(text_signature = "(self, iterator, vocab_size, buffer_size=8192, pattern=None)")]
    pub fn train_from_iterator(
        &mut self,
        py: pyo3::Python<'_>,
        iterator: &pyo3::Bound<'_, pyo3::PyAny>,
        vocab_size: u32,
        buffer_size: usize,
        pattern: Option<String>,
    ) -> PyResult<()> {
        // Use provided pattern or default to GPT-4 pattern
        let pattern_str = pattern.unwrap_or_else(|| GPT4_PATTERN.to_string());

        // Update the stored pattern and compile it
        self.pattern = pattern_str.clone();
        self.compiled_pattern = Regex::new(&pattern_str)
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;

        // Prepare a true Python iterator object.
        // PyObject_GetIter returns a new (owned) reference, so take ownership of it.
        let py_iter: pyo3::Py<pyo3::PyAny> = unsafe {
            pyo3::Bound::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
                .into()
        };

        // Global chunk counts
        let mut counts: AHashMap<CompactString, i32> = AHashMap::new();

        // Temporary buffer we refill under the GIL
        let mut buf: Vec<String> = Vec::with_capacity(buffer_size);

        log::info!("Processing sequences from iterator (buffer_size: {})", buffer_size);
        let mut total_sequences = 0u64;

        // Helper: refill `buf` with up to `buffer_size` strings from the Python iterator.
        // Returns Ok(true) if the iterator is exhausted, Ok(false) otherwise.
        let refill = |buf: &mut Vec<String>| -> PyResult<bool> {
            pyo3::Python::with_gil(|py| {
                buf.clear();
                let it = py_iter.bind(py);
                loop {
                    if buf.len() >= buffer_size {
                        return Ok(false);
                    }
                    // next(it)
                    let next_obj = unsafe {
                        pyo3::Bound::from_owned_ptr_or_opt(py, pyo3::ffi::PyIter_Next(it.as_ptr()))
                    };
                    match next_obj {
                        Some(obj) => {
                            let s: String = obj.extract()?;
                            buf.push(s);
                        }
                        None => {
                            if pyo3::PyErr::occurred(py) {
                                return Err(pyo3::PyErr::fetch(py));
                            } else {
                                return Ok(true); // exhausted
                            }
                        }
                    }
                }
            })
        };

        // Stream ingestion loop: refill under GIL, process without GIL (parallel)
        loop {
            let exhausted = refill(&mut buf)?;
            if buf.is_empty() && exhausted {
                break;
            }
            total_sequences += buf.len() as u64;

            let pattern = self.compiled_pattern.clone();
            let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
                buf.par_iter()
                    .map(|s| {
                        let mut m: AHashMap<CompactString, i32> = AHashMap::new();
                        for mat in pattern.find_iter(s) {
                            let piece = mat.expect("regex match failed").as_str();
                            *m.entry(CompactString::from(piece)).or_default() += 1;
                        }
                        m
                    })
                    .reduce(
                        || AHashMap::new(),
                        |mut a, b| {
                            for (k, v) in b {
                                *a.entry(k).or_default() += v;
                            }
                            a
                        },
                    )
            });

            // Merge local into global (single-threaded)
            for (k, v) in local {
                *counts.entry(k).or_default() += v;
            }

            if exhausted {
                break;
            }
        }

        log::info!("Processed {} sequences total, {} unique", total_sequences, counts.len());

        // Materialize words & counts
        let mut words = Vec::with_capacity(counts.len());
        let mut cvec = Vec::with_capacity(counts.len());
        for (chunk, c) in counts.into_iter() {
            words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as u32).collect()));
            cvec.push(c);
        }

        self.train_core_incremental(words, cvec, vocab_size);
        Ok(())
    }

    /// Return the regex pattern
    pub fn get_pattern(&self) -> String {
        self.pattern.clone()
    }

    /// Return the mergeable ranks (token bytes -> token id / rank)
    pub fn get_mergeable_ranks(&self) -> Vec<(Vec<u8>, u32)> {
        let mut mergeable_ranks = Vec::new();

        // Build vocabulary incrementally from low to high token IDs
        let mut token_bytes: Vec<Vec<u8>> = (0..256_u32).map(|i| vec![i as u8]).collect();
        for (i, bytes) in token_bytes.iter().enumerate() {
            mergeable_ranks.push((bytes.clone(), i as u32));
        }

        // Sort merges by token id (so we can reconstruct bytes progressively)
        let mut sorted_merges: Vec<_> = self.merges.iter().collect();
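        // Replaying merges in ascending token-id order matters because the byte
        // string of a merged token is built from the byte strings of its two
        // children, and a child may itself be a merged token with a lower id.
        // Sorting guarantees every child's bytes are already filled in.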
        sorted_merges.sort_by_key(|&(_, &token_id)| token_id);

        for (&pair, &merged_id) in sorted_merges {
            let (left, right) = pair;
            let mut merged_bytes = token_bytes[left as usize].clone();
            merged_bytes.extend(&token_bytes[right as usize]);
            if token_bytes.len() <= merged_id as usize {
                token_bytes.resize(merged_id as usize + 1, Vec::new());
            }
            token_bytes[merged_id as usize] = merged_bytes.clone();
            mergeable_ranks.push((merged_bytes, merged_id));
        }

        mergeable_ranks
    }

    /// Encode a string into token IDs
    pub fn encode(&self, text: &str) -> Vec<u32> {
        let mut all_ids = Vec::new();

        // Split text using the regex pattern
        for m in self.compiled_pattern.find_iter(text) {
            let chunk = m.expect("regex match failed").as_str();
            // Convert chunk to bytes then to u32 IDs
            let mut ids: Vec<u32> = chunk.bytes().map(|b| b as u32).collect();

            // Apply merges iteratively
            while ids.len() >= 2 {
                // Find the best pair to merge (the one with the lowest merged token id)
                let mut best_pair: Option<(usize, Pair, u32)> = None;
                for i in 0..ids.len() - 1 {
                    let pair: Pair = (ids[i], ids[i + 1]);
                    if let Some(&new_id) = self.merges.get(&pair) {
                        if best_pair.is_none() || new_id < best_pair.unwrap().2 {
                            best_pair = Some((i, pair, new_id));
                        }
                    }
                }

                // If we found a pair to merge, apply it
                if let Some((idx, _pair, new_id)) = best_pair {
                    ids[idx] = new_id;
                    ids.remove(idx + 1);
                } else {
                    // No more merges possible
                    break;
                }
            }

            all_ids.extend(ids);
        }

        all_ids
    }
}

#[pymodule]
fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> {
    pyo3_log::init(); // forwards Rust `log` to Python's `logging`
    m.add_class::<Tokenizer>()?;
    Ok(())
}
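
// A minimal sketch (not part of the original crate) of calling the encoder
// from Rust. The pattern and merge table are set by hand, which is only
// possible because this test module lives in the same file; through Python
// they would normally come from `train_from_iterator`. Assuming the GPT-4
// pattern splits "hi hi" into "hi" and " hi", a single merge (104, 105) -> 256
// for b"hi" should leave the space byte (32) as a raw token.
#[cfg(test)]
mod encode_sketch {
    use super::*;

    #[test]
    fn encode_applies_learned_merges() {
        let mut tok = Tokenizer::new();
        tok.pattern = GPT4_PATTERN.to_string();
        tok.compiled_pattern = Regex::new(GPT4_PATTERN).unwrap();
        // Pretend training produced a single merge: b"hi" -> 256.
        tok.merges.insert((104, 105), 256);
        assert_eq!(tok.encode("hi hi"), vec![256, 32, 256]);
    }
}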