Spaces:
Runtime error
Runtime error
| import requests | |
| import os | |
| import fasttext | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
| import torch | |
# --- Text shown at the top of the Gradio page -------------------------------
title = "Community Tab Language Detection & Translation"
description = """
When comments are created in the community tab, detect the language of the content.
Then, if the detected language is different from the user's language, display an option to translate it.
"""

# --- Hosted language-identification endpoint --------------------------------
# The bearer token is read from the environment. If ACCESS_TOKEN is unset,
# os.environ.get returns None and the header value becomes "Bearer None".
LANG_ID_API_URL = "https://q5esh83u7boq5qwd.us-east-1.aws.endpoints.huggingface.cloud"
ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")
headers = {"Authorization": f"Bearer {ACCESS_TOKEN}"}

# --- Translation model -------------------------------------------------------
# Model and tokenizer are loaded once, at import time, and shared by every
# translation request.
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

# Device index handed to the transformers pipeline: 0 when CUDA is available,
# -1 otherwise.
_cuda_available = torch.cuda.is_available()
device = 0 if _cuda_available else -1
print(f"Is CUDA available: {_cuda_available}")

# Human-readable language names (as offered in the UI) mapped to the language
# codes passed to the translation pipeline as src_lang / tgt_lang.
language_code_map = {
    "English": "eng_Latn",
    "French": "fra_Latn",
    "German": "deu_Latn",
    "Spanish": "spa_Latn",
    "Korean": "kor_Hang",
    "Japanese": "jpn_Jpan",
}
# File name of the fastText language-identification model, expected to sit
# next to this source file.
_LID_MODEL_FILE = "lid218e.bin"

# Prefix fastText puts in front of every predicted label,
# e.g. '__label__eng_Latn'.
_LID_LABEL_PREFIX = "__label__"

# Lazily-loaded fastText model, shared across calls so the binary is read
# from disk only once per process.
_lid_model = None


def _get_lid_model():
    """Return the fastText language-ID model, loading it on first use."""
    global _lid_model
    if _lid_model is None:
        model_full_path = os.path.join(os.path.dirname(__file__), _LID_MODEL_FILE)
        _lid_model = fasttext.load_model(model_full_path)
    return _lid_model


def identify_language(text):
    """Detect the language of ``text`` with the local fastText model.

    Args:
        text: The string whose language should be identified. It may span
            several lines; line breaks are replaced by spaces before
            prediction.

    Returns:
        The top predicted label with the '__label__' prefix removed
        (e.g. 'eng_Latn'), or an empty string if the model returns no label.
    """
    # fastText's predict() handles one line at a time and rejects input that
    # contains a newline, so collapse multi-line text into a single line.
    single_line = " ".join(text.splitlines())

    # e.g., (('__label__eng_Latn',), array([0.81148803]))
    predictions = _get_lid_model().predict(single_line, k=1)
    labels = predictions[0]
    if not labels:
        return ""

    top_label = labels[0]
    # Strip the complete prefix. Its length is derived from the prefix string
    # itself rather than hard-coded, so the whole '__label__' is removed.
    if top_label.startswith(_LID_LABEL_PREFIX):
        top_label = top_label[len(_LID_LABEL_PREFIX):]
    return top_label
def translate(text, src_lang, tgt_lang):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang``.

    Args:
        text: The text to translate.
        src_lang: Source language name; must be a key of ``language_code_map``.
        tgt_lang: Target language name; must be a key of ``language_code_map``.

    Returns:
        The 'translation_text' field of the first pipeline result.

    Raises:
        KeyError: If either language name is missing from ``language_code_map``.
    """
    # A pipeline is built per call because the source/target language codes
    # are supplied as construction arguments.
    translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=language_code_map[src_lang],
        tgt_lang=language_code_map[tgt_lang],
        device=device,
    )
    outputs = translator(text)
    first_output = outputs[0]
    return first_output['translation_text']
def query(text, src_lang, tgt_lang):
    """Gradio handler: translate ``text`` and detect its language.

    Args:
        text: The user-supplied text.
        src_lang: Source language name, forwarded to ``translate``.
        tgt_lang: Target language name, forwarded to ``translate``.

    Returns:
        A two-element list ``[language_code, translation]`` matching the two
        output textboxes of the interface.
    """
    translation = translate(text, src_lang, tgt_lang)
    # Language detection is done locally by identify_language(). The hosted
    # endpoint at LANG_ID_API_URL is deliberately not called here: its response
    # was never part of the returned value, so the request only added latency
    # and made every query fail whenever the endpoint was unreachable or
    # answered with an error payload.
    language_code = identify_language(text)
    return [language_code, translation]
# Sample inputs offered in the UI; each row is [text, source language, target language].
examples = [
    ["Hello, world", "English", "French"],
    ["Can I have a cheeseburger?", "English", "German"],
    ["Hasta la vista", "Spanish", "German"],
    ["동경에 휴가를 간다", "Korean", "Japanese"],
]

# Build the interface: three inputs (text, source language, target language)
# feed query(), whose two return values fill the two output textboxes.
demo = gr.Interface(
    fn=query,
    inputs=[
        gr.Textbox(lines=2),
        gr.Radio(["English", "Spanish", "Korean"], value="English", label="Source Language"),
        gr.Radio(["French", "German", "Japanese"], value="French", label="Target Language"),
    ],
    outputs=[
        gr.Textbox(lines=3, label="Detected Language"),
        gr.Textbox(lines=3, label="Translation"),
    ],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()