import gradio as gr
import json, pickle, re
from collections import Counter
from huggingface_hub import hf_hub_download
import unicodedata
from functools import lru_cache
REPO_ID = "snskrt/sanskrit-morpheme-tokenizer"  # Hub repo hosting vocab.json, morpheme_freq.pkl and config.json
class SanskritMorphemeTokenizer:
def __init__(self):
        self.token_to_id = {}           # morpheme -> integer id
        self.id_to_token = {}           # integer id -> morpheme
        self.morpheme_vocab = set()     # known morphemes, for O(1) membership checks
        self.morpheme_freq = Counter()  # morpheme -> frequency count (loaded from the Hub)
self.unk_token = "[UNK]"
def clean_token(self, token: str):
token = unicodedata.normalize("NFC", token)
token = re.sub(r'[*।॥०-९\d]+', '', token)
token = token.strip()
return token if token else None
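    # Examples (behaviour follows the regex above, which drops asterisks,
    # danda marks, and digits):
    #   clean_token("रामः।")  -> "रामः"
    #   clean_token("॥१२॥")   -> None   (nothing left after cleaning)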
    def _segment_word(self, word: str):
        """
        DP-based segmenter:
        - Minimizes #UNK pieces
        - Then minimizes total pieces
        - Then prefers longer known morphemes
        """
        if not word:
            return [self.unk_token]
        if word in self.morpheme_vocab:
            return [word]

        # optional: cap how far we look ahead; adjust if your morphemes are very long
        max_morph_len = min(30, len(word))

        @lru_cache(None)
        def best(i: int):
            # returns (unk_count, pieces_count, -total_known_len, pieces_list)
            if i == len(word):
                return (0, 0, 0, [])
            best_tuple = (10**9, 10**9, 0, [self.unk_token])  # big sentinel
            # try every piece starting at position i
            for j in range(i + 1, min(len(word), i + max_morph_len) + 1):
                piece = word[i:j]
                is_known = piece in self.morpheme_vocab
                tail = best(j)  # best segmentation of the remainder
                unk_count = (0 if is_known else 1) + tail[0]
                pieces_count = 1 + tail[1]
                total_known_len = (len(piece) if is_known else 0) - tail[2]
                # comparable tuple: 1) fewer UNKs, 2) fewer pieces, 3) longer known coverage
                candidate = (unk_count, pieces_count, -total_known_len, [piece] + tail[3])
                if candidate < best_tuple:
                    best_tuple = candidate
            return best_tuple

        return best(0)[3]
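    # Illustrative segmentation (hypothetical vocab contents; real output depends
    # on the morphemes loaded from the Hub):
    #   with "गच्छ" and "ति" (but not the full word) in morpheme_vocab,
    #   _segment_word("गच्छति") -> ["गच्छ", "ति"]
    #   a word with no known pieces comes back as a single unknown piece,
    #   which encode() maps to the [UNK] id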
def tokenize(self, text: str):
tokens = []
for w in text.split():
cw = self.clean_token(w)
if not cw:
continue
if cw in self.morpheme_vocab:
tokens.append(cw)
else:
tokens.extend(self._segment_word(cw))
return tokens
def encode(self, text: str):
return [self.token_to_id.get(t, self.token_to_id.get(self.unk_token)) for t in self.tokenize(text)]
def decode(self, ids):
return " ".join(self.id_to_token.get(i, self.unk_token) for i in ids)
def load_from_hub(self, repo_id):
vocab_fp = hf_hub_download(repo_id, "vocab.json")
freq_fp = hf_hub_download(repo_id, "morpheme_freq.pkl")
cfg_fp = hf_hub_download(repo_id, "config.json")
with open(vocab_fp, "r", encoding="utf-8") as f:
self.token_to_id = json.load(f)
self.id_to_token = {int(i): tok for tok, i in self.token_to_id.items()}
self.morpheme_vocab = set(self.token_to_id.keys())
with open(freq_fp, "rb") as f:
self.morpheme_freq = Counter(pickle.load(f))
tokenizer = SanskritMorphemeTokenizer()
tokenizer.load_from_hub(REPO_ID)
def run(text):
tokens = tokenizer.tokenize(text)
ids = tokenizer.encode(text)
decoded = tokenizer.decode(ids)
return tokens, ids, decoded
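# Local smoke test (a minimal sketch; actual pieces depend on the vocabulary
# hosted at REPO_ID):
#   tokens, ids, decoded = run("रामः वनं गच्छति")
#   print(tokens)   # list of morpheme pieces, unknowns included as-is
#   print(ids)      # pieces not in the vocab fall back to the [UNK] id
#   print(decoded)  # space-joined morphemes, not the original surface text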
demo = gr.Interface(
fn=run,
inputs=gr.Textbox(label="Input Sanskrit Text", lines=2),
outputs=[gr.JSON(label="Tokens"), gr.JSON(label="Token IDs"), gr.Textbox(label="Decoded")]
)
demo.launch()