from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import os

# Redirect numba's cache to a writable directory and disable JIT compilation
# (avoids permission errors in read-only containerized environments).
os.environ["NUMBA_CACHE_DIR"] = "/tmp/numba_cache"
os.environ["NUMBA_DISABLE_JIT"] = "1"


def nllb():
    """
    Load and return the NLLB (No Language Left Behind) model and tokenizer.

    Loads the facebook/nllb-200-distilled-1.3B model and tokenizer from Hugging
    Face's Transformers library. The model is placed on a GPU if one is
    available and falls back to CPU otherwise.

    Returns:
        tuple: A tuple containing the loaded model and tokenizer.
            - model (transformers.AutoModelForSeq2SeqLM): The loaded NLLB model.
            - tokenizer (transformers.AutoTokenizer): The loaded tokenizer.

    Example usage:
        model, tokenizer = nllb()
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Point the Hugging Face cache at a writable directory; the directory must
    # exist with the correct permissions. TRANSFORMERS_CACHE is kept for older
    # transformers versions, and some versions only honor these variables when
    # they are set before transformers is imported.
    os.environ["HF_HOME"] = "/app/cache/huggingface"
    os.environ["TRANSFORMERS_CACHE"] = "/app/cache/huggingface"

    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-1.3B")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-1.3B").to(device)
    return model, tokenizer
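

# A minimal sketch (not part of the original file): NLLB addresses target
# languages with FLORES-200 codes such as "spa_Latn" or "eng_Latn", which the
# tokenizer registers as additional special tokens. convert_tokens_to_ids is
# the portable way to resolve a code to the token id passed as
# forced_bos_token_id below; the older tokenizer.lang_code_to_id mapping was
# removed from recent transformers releases.
def nllb_lang_id(tokenizer, lang_code):
    """Return the token id for a FLORES-200 language code, e.g. 'spa_Latn'."""
    if lang_code not in tokenizer.additional_special_tokens:
        raise ValueError(f"Unknown NLLB language code: {lang_code}")
    return tokenizer.convert_tokens_to_ids(lang_code)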


def nllb_translate(model, tokenizer, article, language):
    """
    Translate an article using the NLLB model and tokenizer.

    Args:
        model (transformers.AutoModelForSeq2SeqLM): The NLLB model to use for
            translation.
            Example: model, tokenizer = nllb()
        tokenizer (transformers.AutoTokenizer): The tokenizer to use with the
            NLLB model.
            Example: model, tokenizer = nllb()
        article (str): The article text to be translated.
            Example: "This is a sample article."
        language (str): The target language code. Must be either "es" (Spanish)
            or "en" (English).
            Example: "es"

    Returns:
        str: The translated text, or "Translation failed" if an error occurs.
            Example: "Este es un artículo de muestra."
    """
    try:
        # Tokenize the text
        inputs = tokenizer(article, return_tensors="pt")
        # Move the tokenized inputs to the same device as the model
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        if language == "es":
            target_code = "spa_Latn"
        elif language == "en":
            target_code = "eng_Latn"
        else:
            raise ValueError("Unsupported language. Use 'es' or 'en'.")

        # Force the decoder to start with the target-language token.
        # convert_tokens_to_ids replaces tokenizer.lang_code_to_id, which was
        # removed in recent transformers releases. Note that max_length=30
        # truncates longer translations; raise it for full articles.
        translated_tokens = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_code),
            max_length=30,
        )

        # Decode the translation
        text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        return text
    except Exception as e:
        print(f"Error during translation: {e}")
        return "Translation failed"