Spaces:
Running
Running
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
| import gradio as gr | |
| import torch | |
| import os | |
| from datetime import datetime | |
| from .base_interface import BaseInterface | |
| from modules.subtitle_manager import * | |
# Default NLLB checkpoint (Hugging Face repo id); stored as NLLBInference.default_model_size.
DEFAULT_MODEL_SIZE = "facebook/nllb-200-1.3B"

# NLLB-200 checkpoints this module offers (NLLBInference.available_models), largest first.
NLLB_MODELS = [
    "facebook/nllb-200-3.3B",
    "facebook/nllb-200-1.3B",
    "facebook/nllb-200-distilled-600M",
]
class NLLBInference(BaseInterface):
    """Subtitle translation backed by Meta's NLLB-200 models via Hugging Face transformers.

    The seq2seq model and tokenizer are cached on the instance. They are reloaded only
    when a different model name is requested. The "translation" pipeline is rebuilt
    on every ``translate_file`` call so the selected source/target languages always
    take effect.
    """

    def __init__(self):
        super().__init__()
        self.default_model_size = DEFAULT_MODEL_SIZE
        # Name of the checkpoint currently held in self.model / self.tokenizer (None until first load).
        self.current_model_size = None
        self.model = None
        self.tokenizer = None
        self.available_models = NLLB_MODELS
        # Dropdowns show human-readable names; they are mapped to NLLB codes at translate time.
        self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys())
        self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
        # transformers pipeline device index: 0 = first CUDA device, -1 = CPU.
        self.device = 0 if torch.cuda.is_available() else -1
        self.pipeline = None

    def translate_text(self, text):
        """Translate a single string with the pipeline built by the last ``translate_file`` call.

        Raises
        ------
        RuntimeError
            If no translation pipeline has been built yet.
        """
        if self.pipeline is None:
            raise RuntimeError("NLLB pipeline is not initialized; call translate_file first.")
        result = self.pipeline(text)
        return result[0]['translation_text']

    def translate_file(self,
                       fileobjs: list,
                       model_size: str,
                       src_lang: str,
                       tgt_lang: str,
                       add_timestamp: bool,
                       progress=gr.Progress()):
        """
        Translate subtitle files from the source language to the target language.

        Each .srt or .vtt file (extension matched case-insensitively) goes through four
        steps: it is parsed, translated sentence by sentence, written under
        outputs/translations, and included in the returned summary. Files with any
        other extension are skipped.

        Parameters
        ----------
        fileobjs: list
            List of uploaded files from gr.Files(). Each item's ``.name`` is the path
            that is read, and ``.orig_name`` is used to derive the output file name.
        model_size: str
            NLLB model name (Hugging Face repo id, e.g. an entry of NLLB_MODELS) from gr.Dropdown()
        src_lang: str
            Source language display name (a key of NLLB_AVAILABLE_LANGS) from gr.Dropdown()
        tgt_lang: str
            Target language display name (a key of NLLB_AVAILABLE_LANGS) from gr.Dropdown()
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.

        Returns
        -------
        str
            On success, a "Done!" message followed by every translated subtitle.
            If anything raises, "Error: <message>". In both cases
            ``release_cuda_memory`` and ``remove_input_files`` are called afterwards.
        """
        # Tolerate a None upload list so the cleanup in ``finally`` cannot itself raise
        # and mask the returned message.
        fileobjs = fileobjs or []
        try:
            self._load_model(model_size, progress)
            # Rebuilt on every call so a changed language pair takes effect even when
            # the model itself is cached.
            self.pipeline = pipeline("translation",
                                     model=self.model,
                                     tokenizer=self.tokenizer,
                                     src_lang=NLLB_AVAILABLE_LANGS[src_lang],
                                     tgt_lang=NLLB_AVAILABLE_LANGS[tgt_lang],
                                     device=self.device)

            files_info = {}
            for fileobj in fileobjs:
                file_name, file_ext = os.path.splitext(os.path.basename(fileobj.orig_name))
                subtitle = self._translate_subtitle_file(fileobj.name, file_name, file_ext,
                                                         add_timestamp, progress)
                if subtitle is None:
                    # Unsupported format: skip it rather than reuse a previous file's
                    # subtitle (or hit an unbound variable on the first file).
                    print(f"Skipping unsupported subtitle file: {fileobj.orig_name}")
                    continue
                files_info[file_name] = subtitle

            total_result = ''.join(f'------------------------------------\n{name}\n\n{text}'
                                   for name, text in files_info.items())
            return f"Done! Subtitle is in the outputs/translations folder.\n\n{total_result}"
        except Exception as e:
            # UI boundary: report the failure to the Gradio output instead of raising.
            return f"Error: {str(e)}"
        finally:
            self.release_cuda_memory()
            self.remove_input_files([fileobj.name for fileobj in fileobjs])

    def _load_model(self, model_size, progress):
        """Load the NLLB model and tokenizer for ``model_size`` unless they are already cached."""
        if model_size == self.current_model_size and self.model is not None:
            return
        print("\nInitializing NLLB Model..\n")
        progress(0, desc="Initializing NLLB Model..")
        model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
                                                      cache_dir=os.path.join("models", "NLLB"))
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
                                                  cache_dir=os.path.join("models", "NLLB", "tokenizers"))
        # Commit to the cache only after both loads succeed. Otherwise a failed download
        # could leave current_model_size naming a model that is not actually loaded.
        self.model = model
        self.tokenizer = tokenizer
        self.current_model_size = model_size

    def _translate_subtitle_file(self, file_path, file_name, file_ext, add_timestamp, progress):
        """Translate one subtitle file, write it under outputs/translations, and return its text.

        Returns None when ``file_ext`` is not a supported subtitle format (.srt / .vtt).
        """
        # Parser/serializer pairs, presumably provided by the
        # ``modules.subtitle_manager`` star import.
        handlers = {
            ".srt": (parse_srt, get_serialized_srt),
            ".vtt": (parse_vtt, get_serialized_vtt),
        }
        ext = file_ext.lower()
        if ext not in handlers:
            return None
        parse, serialize = handlers[ext]

        parsed_dicts = parse(file_path=file_path)
        total_progress = len(parsed_dicts)
        for index, dic in enumerate(parsed_dicts):
            progress(index / total_progress, desc="Translating..")
            dic["sentence"] = self.translate_text(dic["sentence"])
        subtitle = serialize(parsed_dicts)

        if add_timestamp:
            timestamp = datetime.now().strftime("%m%d%H%M%S")
            output_name = f"{file_name}-{timestamp}"
        else:
            output_name = file_name
        output_path = os.path.join("outputs", "translations", output_name)
        write_file(subtitle, f"{output_path}{ext}")
        return subtitle
# Human-readable language name -> NLLB/FLORES-200 language code ("<iso639-3>_<Script>").
# Insertion order matters: NLLBInference builds its source/target dropdown lists from
# list(NLLB_AVAILABLE_LANGS.keys()).
# NOTE(review): a few codes (e.g. arb_Latn, min_Arab, sat_Olck) are FLORES-200 codes
# that may not appear in the NLLB tokenizer's language-code list. Verify them before
# relying on those entries.
NLLB_AVAILABLE_LANGS = {
    "Acehnese (Arabic script)": "ace_Arab",
    "Acehnese (Latin script)": "ace_Latn",
    "Mesopotamian Arabic": "acm_Arab",
    "Ta’izzi-Adeni Arabic": "acq_Arab",
    "Tunisian Arabic": "aeb_Arab",
    "Afrikaans": "afr_Latn",
    "South Levantine Arabic": "ajp_Arab",
    "Akan": "aka_Latn",
    "Amharic": "amh_Ethi",
    "North Levantine Arabic": "apc_Arab",
    "Modern Standard Arabic": "arb_Arab",
    "Modern Standard Arabic (Romanized)": "arb_Latn",
    "Najdi Arabic": "ars_Arab",
    "Moroccan Arabic": "ary_Arab",
    "Egyptian Arabic": "arz_Arab",
    "Assamese": "asm_Beng",
    "Asturian": "ast_Latn",
    "Awadhi": "awa_Deva",
    "Central Aymara": "ayr_Latn",
    "South Azerbaijani": "azb_Arab",
    "North Azerbaijani": "azj_Latn",
    "Bashkir": "bak_Cyrl",
    "Bambara": "bam_Latn",
    "Balinese": "ban_Latn",
    "Belarusian": "bel_Cyrl",
    "Bemba": "bem_Latn",
    "Bengali": "ben_Beng",
    "Bhojpuri": "bho_Deva",
    "Banjar (Arabic script)": "bjn_Arab",
    "Banjar (Latin script)": "bjn_Latn",
    "Standard Tibetan": "bod_Tibt",
    "Bosnian": "bos_Latn",
    "Buginese": "bug_Latn",
    "Bulgarian": "bul_Cyrl",
    "Catalan": "cat_Latn",
    "Cebuano": "ceb_Latn",
    "Czech": "ces_Latn",
    "Chokwe": "cjk_Latn",
    "Central Kurdish": "ckb_Arab",
    "Crimean Tatar": "crh_Latn",
    "Welsh": "cym_Latn",
    "Danish": "dan_Latn",
    "German": "deu_Latn",
    "Southwestern Dinka": "dik_Latn",
    "Dyula": "dyu_Latn",
    "Dzongkha": "dzo_Tibt",
    "Greek": "ell_Grek",
    "English": "eng_Latn",
    "Esperanto": "epo_Latn",
    "Estonian": "est_Latn",
    "Basque": "eus_Latn",
    "Ewe": "ewe_Latn",
    "Faroese": "fao_Latn",
    "Fijian": "fij_Latn",
    "Finnish": "fin_Latn",
    "Fon": "fon_Latn",
    "French": "fra_Latn",
    "Friulian": "fur_Latn",
    "Nigerian Fulfulde": "fuv_Latn",
    "Scottish Gaelic": "gla_Latn",
    "Irish": "gle_Latn",
    "Galician": "glg_Latn",
    "Guarani": "grn_Latn",
    "Gujarati": "guj_Gujr",
    "Haitian Creole": "hat_Latn",
    "Hausa": "hau_Latn",
    "Hebrew": "heb_Hebr",
    "Hindi": "hin_Deva",
    "Chhattisgarhi": "hne_Deva",
    "Croatian": "hrv_Latn",
    "Hungarian": "hun_Latn",
    "Armenian": "hye_Armn",
    "Igbo": "ibo_Latn",
    "Ilocano": "ilo_Latn",
    "Indonesian": "ind_Latn",
    "Icelandic": "isl_Latn",
    "Italian": "ita_Latn",
    "Javanese": "jav_Latn",
    "Japanese": "jpn_Jpan",
    "Kabyle": "kab_Latn",
    "Jingpho": "kac_Latn",
    "Kamba": "kam_Latn",
    "Kannada": "kan_Knda",
    "Kashmiri (Arabic script)": "kas_Arab",
    "Kashmiri (Devanagari script)": "kas_Deva",
    "Georgian": "kat_Geor",
    "Central Kanuri (Arabic script)": "knc_Arab",
    "Central Kanuri (Latin script)": "knc_Latn",
    "Kazakh": "kaz_Cyrl",
    "Kabiyè": "kbp_Latn",
    "Kabuverdianu": "kea_Latn",
    "Khmer": "khm_Khmr",
    "Kikuyu": "kik_Latn",
    "Kinyarwanda": "kin_Latn",
    "Kyrgyz": "kir_Cyrl",
    "Kimbundu": "kmb_Latn",
    "Northern Kurdish": "kmr_Latn",
    "Kikongo": "kon_Latn",
    "Korean": "kor_Hang",
    "Lao": "lao_Laoo",
    "Ligurian": "lij_Latn",
    "Limburgish": "lim_Latn",
    "Lingala": "lin_Latn",
    "Lithuanian": "lit_Latn",
    "Lombard": "lmo_Latn",
    "Latgalian": "ltg_Latn",
    "Luxembourgish": "ltz_Latn",
    "Luba-Kasai": "lua_Latn",
    "Ganda": "lug_Latn",
    "Luo": "luo_Latn",
    "Mizo": "lus_Latn",
    "Standard Latvian": "lvs_Latn",
    "Magahi": "mag_Deva",
    "Maithili": "mai_Deva",
    "Malayalam": "mal_Mlym",
    "Marathi": "mar_Deva",
    "Minangkabau (Arabic script)": "min_Arab",
    "Minangkabau (Latin script)": "min_Latn",
    "Macedonian": "mkd_Cyrl",
    "Plateau Malagasy": "plt_Latn",
    "Maltese": "mlt_Latn",
    "Meitei (Bengali script)": "mni_Beng",
    "Halh Mongolian": "khk_Cyrl",
    "Mossi": "mos_Latn",
    "Maori": "mri_Latn",
    "Burmese": "mya_Mymr",
    "Dutch": "nld_Latn",
    "Norwegian Nynorsk": "nno_Latn",
    "Norwegian Bokmål": "nob_Latn",
    "Nepali": "npi_Deva",
    "Northern Sotho": "nso_Latn",
    "Nuer": "nus_Latn",
    "Nyanja": "nya_Latn",
    "Occitan": "oci_Latn",
    "West Central Oromo": "gaz_Latn",
    "Odia": "ory_Orya",
    "Pangasinan": "pag_Latn",
    "Eastern Panjabi": "pan_Guru",
    "Papiamento": "pap_Latn",
    "Western Persian": "pes_Arab",
    "Polish": "pol_Latn",
    "Portuguese": "por_Latn",
    "Dari": "prs_Arab",
    "Southern Pashto": "pbt_Arab",
    "Ayacucho Quechua": "quy_Latn",
    "Romanian": "ron_Latn",
    "Rundi": "run_Latn",
    "Russian": "rus_Cyrl",
    "Sango": "sag_Latn",
    "Sanskrit": "san_Deva",
    "Santali": "sat_Olck",
    "Sicilian": "scn_Latn",
    "Shan": "shn_Mymr",
    "Sinhala": "sin_Sinh",
    "Slovak": "slk_Latn",
    "Slovenian": "slv_Latn",
    "Samoan": "smo_Latn",
    "Shona": "sna_Latn",
    "Sindhi": "snd_Arab",
    "Somali": "som_Latn",
    "Southern Sotho": "sot_Latn",
    "Spanish": "spa_Latn",
    "Tosk Albanian": "als_Latn",
    "Sardinian": "srd_Latn",
    "Serbian": "srp_Cyrl",
    "Swati": "ssw_Latn",
    "Sundanese": "sun_Latn",
    "Swedish": "swe_Latn",
    "Swahili": "swh_Latn",
    "Silesian": "szl_Latn",
    "Tamil": "tam_Taml",
    "Tatar": "tat_Cyrl",
    "Telugu": "tel_Telu",
    "Tajik": "tgk_Cyrl",
    "Tagalog": "tgl_Latn",
    "Thai": "tha_Thai",
    "Tigrinya": "tir_Ethi",
    "Tamasheq (Latin script)": "taq_Latn",
    "Tamasheq (Tifinagh script)": "taq_Tfng",
    "Tok Pisin": "tpi_Latn",
    "Tswana": "tsn_Latn",
    "Tsonga": "tso_Latn",
    "Turkmen": "tuk_Latn",
    "Tumbuka": "tum_Latn",
    "Turkish": "tur_Latn",
    "Twi": "twi_Latn",
    "Central Atlas Tamazight": "tzm_Tfng",
    "Uyghur": "uig_Arab",
    "Ukrainian": "ukr_Cyrl",
    "Umbundu": "umb_Latn",
    "Urdu": "urd_Arab",
    "Northern Uzbek": "uzn_Latn",
    "Venetian": "vec_Latn",
    "Vietnamese": "vie_Latn",
    "Waray": "war_Latn",
    "Wolof": "wol_Latn",
    "Xhosa": "xho_Latn",
    "Eastern Yiddish": "ydd_Hebr",
    "Yoruba": "yor_Latn",
    "Yue Chinese": "yue_Hant",
    "Chinese (Simplified)": "zho_Hans",
    "Chinese (Traditional)": "zho_Hant",
    "Standard Malay": "zsm_Latn",
    "Zulu": "zul_Latn",
}