import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Define a subset of popular languages mapped to FLORES-200 codes for better UX.
# NLLB supports 200+ languages, but a dropdown of 200 items can be unwieldy.
# Codes reference: https://github.com/facebookresearch/flores/blob/main/flores200/README.md
LANGUAGE_CODES = {
    "English": "eng_Latn",
    "French": "fra_Latn",
    "Spanish": "spa_Latn",
    "German": "deu_Latn",
    "Chinese (Simplified)": "zho_Hans",
    "Chinese (Traditional)": "zho_Hant",
    "Hindi": "hin_Deva",
    "Arabic": "arb_Arab",
    "Russian": "rus_Cyrl",
    "Portuguese": "por_Latn",
    "Japanese": "jpn_Jpan",
    "Korean": "kor_Hang",
    "Italian": "ita_Latn",
    "Dutch": "nld_Latn",
    "Turkish": "tur_Latn",
    "Vietnamese": "vie_Latn",
    "Indonesian": "ind_Latn",
    "Persian": "pes_Arab",
    "Polish": "pol_Latn",
    "Ukrainian": "ukr_Cyrl",
    "Swahili": "swh_Latn",
    "Urdu": "urd_Arab",
    "Bengali": "ben_Beng",
    "Tamil": "tam_Taml",
}

MODEL_NAME = "facebook/nllb-200-distilled-600M"

_model = None
_tokenizer = None


def get_device():
    """Determines the best available device."""
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def load_model():
    """
    Loads the model and tokenizer lazily (singleton pattern).
    """
    global _model, _tokenizer
    if _model is None:
        print(f"Loading {MODEL_NAME}...")
        device = get_device()
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
        print("Model loaded successfully.")
    return _model, _tokenizer


def translate_text(text, src_lang_name, tgt_lang_name):
    """
    Translates text from the source language to the target language using NLLB.
    """
    if not text or not text.strip():
        return ""
    try:
        model, tokenizer = load_model()
        device = model.device
        # Look up the NLLB/FLORES-200 codes, falling back to English -> French.
        src_code = LANGUAGE_CODES.get(src_lang_name, "eng_Latn")
        tgt_code = LANGUAGE_CODES.get(tgt_lang_name, "fra_Latn")
        # Prepare inputs; setting src_lang tells the tokenizer which language tag to prepend.
        tokenizer.src_lang = src_code
        inputs = tokenizer(text, return_tensors="pt").to(device)
        # Generate the translation.
        # forced_bos_token_id forces the model to start decoding in the target language.
        # convert_tokens_to_ids is used because newer versions of the fast NLLB
        # tokenizer no longer expose lang_code_to_id.
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
                max_length=200,
            )
        # Decode the output, dropping special tokens (language tags, </s>).
        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    except Exception as e:
        return f"Error during translation: {str(e)}"