import json
import pickle
import re
import unicodedata
from collections import Counter
from functools import lru_cache

import gradio as gr
from huggingface_hub import hf_hub_download

REPO_ID = "snskrt/sanskrit-morpheme-tokenizer"  # change this


class SanskritMorphemeTokenizer:
    def __init__(self):
        self.token_to_id = {}
        self.id_to_token = {}
        self.morpheme_vocab = set()
        self.morpheme_freq = Counter()
        self.unk_token = "[UNK]"

    def clean_token(self, token: str):
        """Normalize to NFC and strip asterisks, dandas, and digits."""
        token = unicodedata.normalize("NFC", token)
        token = re.sub(r'[*।॥०-९\d]+', '', token)
        token = token.strip()
        return token if token else None

    def _segment_word(self, word: str):
        """
        DP-based segmenter:
        - Minimizes #UNK pieces
        - Then minimizes total pieces
        - Then prefers longer known morphemes
        """
        if not word:
            return [self.unk_token]
        if word in self.morpheme_vocab:
            return [word]

        # Cap how far we look ahead; adjust if your morphemes are very long.
        max_morph_len = min(30, len(word))

        @lru_cache(None)
        def best(i: int):
            # Returns (unk_count, pieces_count, -avg_known_len, pieces_list)
            # for the best segmentation of word[i:].
            if i == len(word):
                return (0, 0, 0.0, [])
            best_tuple = (10**9, 10**9, 0.0, [self.unk_token])  # big sentinel
            # Try all pieces starting at i.
            for j in range(i + 1, min(len(word), i + max_morph_len) + 1):
                piece = word[i:j]
                is_known = piece in self.morpheme_vocab
                piece_unk = 0 if is_known else 1  # UNK cost for this piece
                tail = best(j)                    # best segmentation of the remainder
                unk_count = piece_unk + tail[0]
                pieces_count = 1 + tail[1]
                # Tiebreak: prefer longer known pieces. The tail stores
                # -avg_known_len, so invert it back into a total known length.
                known_len = len(piece) if is_known else 0
                total_known_len = known_len + (-tail[2]) * max(1, tail[1])
                # Comparable tuple: 1) fewer UNKs, 2) fewer pieces,
                # 3) longer total known length.
                candidate = (unk_count, pieces_count,
                             -(total_known_len / pieces_count), [piece] + tail[3])
                if candidate < best_tuple:
                    best_tuple = candidate
            return best_tuple

        return best(0)[3]

    def tokenize(self, text: str):
        tokens = []
        for w in text.split():
            cw = self.clean_token(w)
            if not cw:
                continue
            if cw in self.morpheme_vocab:
                tokens.append(cw)
            else:
                tokens.extend(self._segment_word(cw))
        return tokens

    def encode(self, text: str):
        return [self.token_to_id.get(t, self.token_to_id.get(self.unk_token))
                for t in self.tokenize(text)]

    def decode(self, ids):
        return " ".join(self.id_to_token.get(i, self.unk_token) for i in ids)

    def load_from_hub(self, repo_id):
        vocab_fp = hf_hub_download(repo_id, "vocab.json")
        freq_fp = hf_hub_download(repo_id, "morpheme_freq.pkl")
        cfg_fp = hf_hub_download(repo_id, "config.json")  # fetched alongside; not needed for tokenization
        with open(vocab_fp, "r", encoding="utf-8") as f:
            self.token_to_id = json.load(f)
        self.id_to_token = {int(i): tok for tok, i in self.token_to_id.items()}
        self.morpheme_vocab = set(self.token_to_id.keys())
        with open(freq_fp, "rb") as f:
            self.morpheme_freq = Counter(pickle.load(f))


tokenizer = SanskritMorphemeTokenizer()
tokenizer.load_from_hub(REPO_ID)


def run(text):
    tokens = tokenizer.tokenize(text)
    ids = tokenizer.encode(text)
    decoded = tokenizer.decode(ids)
    return tokens, ids, decoded


demo = gr.Interface(
    fn=run,
    inputs=gr.Textbox(label="Input Sanskrit Text", lines=2),
    outputs=[
        gr.JSON(label="Tokens"),
        gr.JSON(label="Token IDs"),
        gr.Textbox(label="Decoded"),
    ],
)

demo.launch()
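
# Optional quick sanity check (illustrative only): run this in place of demo.launch(),
# or in a separate script, since launch() blocks. The sample line below is just example
# input; actual tokens and IDs depend on the vocabulary published in REPO_ID.
# sample = "धर्मक्षेत्रे कुरुक्षेत्रे समवेता युयुत्सवः"
# print(tokenizer.tokenize(sample))
# print(tokenizer.encode(sample))
# print(tokenizer.decode(tokenizer.encode(sample)))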