# anycoder-1970d9cb / utils.py
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Define a subset of popular languages mapped to FLORES-200 codes for better UX.
# NLLB supports 200+, but a dropdown of 200 items can be unwieldy.
# Codes reference: https://github.com/facebookresearch/flores/blob/main/flores200/README.md
LANGUAGE_CODES = {
"English": "eng_Latn",
"French": "fra_Latn",
"Spanish": "spa_Latn",
"German": "deu_Latn",
"Chinese (Simplified)": "zho_Hans",
"Chinese (Traditional)": "zho_Hant",
"Hindi": "hin_Deva",
"Arabic": "arb_Arab",
"Russian": "rus_Cyrl",
"Portuguese": "por_Latn",
"Japanese": "jpn_Jpan",
"Korean": "kor_Hang",
"Italian": "ita_Latn",
"Dutch": "nld_Latn",
"Turkish": "tur_Latn",
"Vietnamese": "vie_Latn",
"Indonesian": "ind_Latn",
"Persian": "pes_Arab",
"Polish": "pol_Latn",
"Ukrainian": "ukr_Cyrl",
"Swahili": "swh_Latn",
"Urdu": "urd_Arab",
"Bengali": "ben_Beng",
"Tamil": "tam_Taml"
}
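
# A minimal sketch of how this mapping might feed a UI (assuming a Gradio
# front end; gr.Dropdown and the variable names below are illustrative, not
# part of this module):
#
#     import gradio as gr
#     language_names = sorted(LANGUAGE_CODES)  # display labels for the dropdowns
#     src_lang = gr.Dropdown(choices=language_names, value="English", label="Source")
#     tgt_lang = gr.Dropdown(choices=language_names, value="French", label="Target")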

MODEL_NAME = "facebook/nllb-200-distilled-600M"

# Module-level singletons, populated lazily by load_model().
_model = None
_tokenizer = None


def get_device():
"""Determines the best available device."""
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
return "cpu"


def load_model():
"""
Loads the model and tokenizer lazily (singleton pattern).
"""
global _model, _tokenizer
if _model is None:
print(f"Loading {MODEL_NAME}...")
device = get_device()
_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
print("Model loaded successfully.")
return _model, _tokenizer
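
# Usage sketch (an assumption about caller behavior, not required by this
# module): invoking load_model() once at startup pays the download/load cost
# up front, and later calls just return the cached singletons.
#
#     model, tokenizer = load_model()   # first call loads the weights
#     same_model, _ = load_model()      # reuses the cached instance
#     assert model is same_model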


def translate_text(text, src_lang_name, tgt_lang_name):
    """
    Translates text from src_lang_name to tgt_lang_name using NLLB.

    Unknown language names fall back to English (source) and French (target);
    errors are returned as a string so a UI caller can display them directly.
    """
if not text:
return ""
try:
model, tokenizer = load_model()
device = model.device
# Get NLLB specific codes
src_code = LANGUAGE_CODES.get(src_lang_name, "eng_Latn")
tgt_code = LANGUAGE_CODES.get(tgt_lang_name, "fra_Latn")
        # Prepare inputs; truncate overlong text to the tokenizer's limit so
        # very long strings do not overflow the model's context window.
        tokenizer.src_lang = src_code
        inputs = tokenizer(text, return_tensors="pt", truncation=True).to(device)
        # Generate translation.
        # forced_bos_token_id forces the model to start generating in the
        # target language; convert_tokens_to_ids is used because the older
        # tokenizer.lang_code_to_id mapping was removed in recent transformers
        # releases. inference_mode avoids building autograd state.
        with torch.inference_mode():
            generated_tokens = model.generate(
                **inputs,
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
                max_length=200
            )
# Decode output
result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
return result
except Exception as e:
return f"Error during translation: {str(e)}"