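"""Translation model wrapper: loads M2M100, NLLB-200, or mT5 checkpoints (plain
Transformers or CTranslate2 conversions) and translates text from a Whisper-detected
source language to a selected target language.
"""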
import os
import warnings
from typing import Optional

import ctranslate2
import huggingface_hub
import requests
import torch
import transformers

from src.config import ModelConfig
from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code

class TranslationModel:
    def __init__(
        self,
        modelConfig: ModelConfig,
        device: Optional[str] = None,
        whisperLang: Optional[TranslationLang] = None,
        translationLang: Optional[TranslationLang] = None,
        batchSize: int = 2,
        noRepeatNgramSize: int = 3,
        numBeams: int = 2,
        downloadRoot: Optional[str] = None,
        localFilesOnly: bool = False,
        loadModel: bool = False,
    ):
| """Initializes the M2M100 / Nllb-200 / mt5 model. | |
| Args: | |
| modelConfig: Config of the model to use (distilled-600M, distilled-1.3B, | |
| 1.3B, 3.3B...) or a path to a converted | |
| model directory. When a size is configured, the converted model is downloaded | |
| from the Hugging Face Hub. | |
| device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, | |
| ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia). | |
| device_index: Device ID to use. | |
| The model can also be loaded on multiple GPUs by passing a list of IDs | |
| (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel | |
| when transcribe() is called from multiple Python threads (see also num_workers). | |
| compute_type: Type to use for computation. | |
| See https://opennmt.net/CTranslate2/quantization.html. | |
| cpu_threads: Number of threads to use when running on CPU (4 by default). | |
| A non zero value overrides the OMP_NUM_THREADS environment variable. | |
| num_workers: When transcribe() is called from multiple Python threads, | |
| having multiple workers enables true parallelism when running the model | |
| (concurrent calls to self.model.generate() will run in parallel). | |
| This can improve the global throughput at the cost of increased memory usage. | |
| downloadRoot: Directory where the models should be saved. If not set, the models | |
| are saved in the standard Hugging Face cache directory. | |
| localFilesOnly: If True, avoid downloading the file and return the path to the | |
| local cached file if it exists. | |
| """ | |
        self.modelConfig = modelConfig
        self.whisperLang = whisperLang  # self.translationLangWhisper = get_lang_from_whisper_code(whisperLang.code.lower() if whisperLang is not None else "en")
        self.translationLang = translationLang

        if translationLang is None:
            return

        self.batchSize = batchSize
        self.noRepeatNgramSize = noRepeatNgramSize
        self.numBeams = numBeams

        if os.path.isdir(modelConfig.url):
            self.modelPath = modelConfig.url
        else:
            self.modelPath = download_model(
                modelConfig,
                localFilesOnly=localFilesOnly,
                cacheDir=downloadRoot,
            )

        if device is None:
            if torch.cuda.is_available():
                device = "cuda" if "ct2" in self.modelPath else "cuda:0"
            else:
                device = "cpu"

        self.device = device

        if loadModel:
            self.load_model()

    def load_model(self):
        print('\n\nLoading model: %s\n\n' % self.modelPath)
        if "ct2" in self.modelPath:
            # CTranslate2-converted checkpoints: the tokenizer still comes from Transformers,
            # the translator itself from ctranslate2.
            if "nllb" in self.modelPath:
                self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.nllb.code)
                self.targetPrefix = [self.translationLang.nllb.code]
            elif "m2m100" in self.modelPath:
                self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
                self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
            self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
        elif "mt5" in self.modelPath:
            # mT5 checkpoints expect a "<src>2<tgt>: " task prefix in front of the input text.
            self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
            self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False)  # requires spiece.model
            self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
            self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
        else:
            # Plain Transformers checkpoints (M2M100 or NLLB) go through a translation pipeline.
            self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
            self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
            if "m2m100" in self.modelPath:
                self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.m2m100.code, tgt_lang=self.translationLang.m2m100.code)
            else:  # NLLB
                self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.nllb.code, tgt_lang=self.translationLang.nllb.code)

    def release_vram(self):
        try:
            if torch.cuda.is_available():
                if "ct2" not in self.modelPath:
                    # Move the Transformers model back to the CPU before dropping the reference.
                    device = torch.device("cpu")
                    self.transModel.to(device)
                del self.transModel
                torch.cuda.empty_cache()
            print("release vram end.")
        except Exception as e:
            print("Error release vram: " + str(e))

    def translation(self, text: str, max_length: int = 400):
        """Translates text from whisperLang to translationLang with the loaded model.

        Returns the translated string, or None if the translation failed.
        """
        output = None
        result = None
        try:
            if "ct2" in self.modelPath:
                source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(text))
                output = self.transModel.translate_batch([source], target_prefix=[self.targetPrefix], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
                target = output[0].hypotheses[0][1:]
                result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
            elif "mt5" in self.modelPath:
                output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)  # , num_return_sequences=2
                result = output[0]['generated_text']
            else:  # M2M100 & NLLB
                output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
                result = output[0]['translation_text']
        except Exception as e:
            print("Error translation text: " + str(e))
        return result

_MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
           "ct2fast-nllb-200-distilled-1.3B-int8_float16",
           "ct2fast-nllb-200-3.3B-int8_float16",
           "nllb-200-3.3B-ct2-float16", "nllb-200-1.3B-ct2", "nllb-200-1.3B-ct2-int8", "nllb-200-1.3B-ct2-float16",
           "nllb-200-distilled-1.3B-ct2", "nllb-200-distilled-1.3B-ct2-int8", "nllb-200-distilled-1.3B-ct2-float16",
           "nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
           "m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
           "m2m100_1.2B", "m2m100_418M",
           "mt5-zh-ja-en-trimmed",
           "mt5-zh-ja-en-trimmed-fine-tuned-v1"]


def check_model_name(name):
    return any(allowed_name in name for allowed_name in _MODELS)
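
# For illustration, behavior that follows from the substring check in check_model_name:
#   check_model_name("nllb-200-distilled-600M-ct2")    -> True  ("distilled-600M" is an allowed substring)
#   check_model_name("facebook/some-unrelated-model")  -> False (contains no allowed substring)
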
def download_model(
    modelConfig: ModelConfig,
    outputDir: Optional[str] = None,
    localFilesOnly: bool = False,
    cacheDir: Optional[str] = None,
):
    """Downloads a translation model (NLLB-200, M2M100, or mT5) from the Hugging Face Hub.

    "download_model" is referenced from the "utils.py" script of the "faster_whisper"
    project, authored by guillaumekln. The repository to download is taken from
    modelConfig.url.

    Args:
      modelConfig: Config of the model to download (facebook/nllb-200-distilled-600M,
        facebook/nllb-200-distilled-1.3B, facebook/nllb-200-1.3B, facebook/nllb-200-3.3B...).
      outputDir: Directory where the model should be saved. If not set, the model is
        saved in the cache directory.
      localFilesOnly: If True, avoid downloading the file and return the path to the
        local cached file if it exists.
      cacheDir: Path to the folder where cached files are stored.

    Returns:
      The path to the downloaded model.

    Raises:
      ValueError: if the model name is invalid.
    """
    if not check_model_name(modelConfig.name):
        raise ValueError(
            "Invalid model name '%s', expected one of: %s" % (modelConfig.name, ", ".join(_MODELS))
        )

    repoId = modelConfig.url  # "facebook/nllb-200-%s" %

    allowPatterns = [
        "config.json",
        "generation_config.json",
        "model.bin",
        "pytorch_model.bin",
        "pytorch_model.bin.index.json",
        "pytorch_model-*.bin",
        "pytorch_model-00001-of-00003.bin",
        "pytorch_model-00002-of-00003.bin",
        "pytorch_model-00003-of-00003.bin",
        "sentencepiece.bpe.model",
        "tokenizer.json",
        "tokenizer_config.json",
        "shared_vocabulary.txt",
        "shared_vocabulary.json",
        "special_tokens_map.json",
        "spiece.model",
        "vocab.json",  # m2m100
    ]

    kwargs = {
        "local_files_only": localFilesOnly,
        "allow_patterns": allowPatterns,
        # "tqdm_class": disabled_tqdm,
    }

    if outputDir is not None:
        kwargs["local_dir"] = outputDir
        kwargs["local_dir_use_symlinks"] = False

    if cacheDir is not None:
        kwargs["cache_dir"] = cacheDir

    try:
        return huggingface_hub.snapshot_download(repoId, **kwargs)
    except (
        huggingface_hub.utils.HfHubHTTPError,
        requests.exceptions.ConnectionError,
    ) as exception:
        warnings.warn(
            "An error occurred while synchronizing the model %s from the Hugging Face Hub:\n%s"
            % (repoId, exception)
        )
        warnings.warn(
            "Trying to load the model directly from the local cache, if it exists."
        )

        kwargs["local_files_only"] = True
        return huggingface_hub.snapshot_download(repoId, **kwargs)
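

# Minimal usage sketch (illustrative only). It assumes that a ModelConfig for one of the
# supported checkpoints can be built from src.config (the exact constructor is not shown
# in this file) and that get_lang_from_whisper_code resolves Whisper language codes such
# as "ja" and "en"; adapt both to the real application configuration before running.
if __name__ == "__main__":
    exampleConfig = ModelConfig(...)  # hypothetical: fill in name/url/tokenizer_url for an NLLB or M2M100 model
    model = TranslationModel(
        modelConfig=exampleConfig,
        whisperLang=get_lang_from_whisper_code("ja"),
        translationLang=get_lang_from_whisper_code("en"),
        loadModel=True,
    )
    print(model.translation("こんにちは、世界"))
    model.release_vram()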