Spaces:
Runtime error
Runtime error
| from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer | |
| from tqdm import tqdm | |
| from typing import TextIO, List | |
| import argparse | |
| import torch | |
| from dataset import get_dataloader, count_lines | |
| import os | |
| def main( | |
| sentences_path, | |
| output_path, | |
| source_lang, | |
| target_lang, | |
| batch_size, | |
| model_name: str = "facebook/m2m100_1.2B", | |
| tensorrt: bool = False, | |
| precision: int = 32, | |
| max_length: int = 128, | |
| ): | |
| if not os.path.exists(os.path.dirname(output_path)): | |
| os.makedirs(os.path.dirname(output_path)) | |
| print("Loading tokenizer...") | |
| tokenizer = M2M100Tokenizer.from_pretrained(model_name) | |
| print("Loading model...") | |
| model = M2M100ForConditionalGeneration.from_pretrained(model_name) | |
| print(f"Model loaded.\n") | |
| tokenizer.src_lang = source_lang | |
| lang_code_to_idx = tokenizer.lang_code_to_id[target_lang] | |
| model.eval() | |
| total_lines: int = count_lines(sentences_path) | |
| print(f"We will translate {total_lines} lines.") | |
| data_loader = get_dataloader( | |
| filename=sentences_path, | |
| tokenizer=tokenizer, | |
| batch_size=batch_size, | |
| max_length=128, | |
| ) | |
| if precision == 16: | |
| dtype = torch.float16 | |
| elif precision == 32: | |
| dtype = torch.float32 | |
| elif precision == 64: | |
| dtype = torch.float64 | |
| else: | |
| raise ValueError("Precision must be 16, 32 or 64.") | |
| if tensorrt: | |
| device = "cuda" | |
| from torch2trt import torch2trt | |
| model.to(device, dtype=dtype) | |
| model = torch2trt( | |
| model, | |
| [torch.randn((batch_size, max_length)).to(device, dtype=torch.long)], | |
| ) | |
| else: | |
| if torch.cuda.is_available(): | |
| device = "cuda" | |
| else: | |
| device = "cpu" | |
| print("CUDA not available. Using CPU. This will be slow.") | |
| model.to(device, dtype=dtype) | |
| with tqdm(total=total_lines, desc="Dataset translation") as pbar, open( | |
| output_path, "w+", encoding="utf-8" | |
| ) as output_file: | |
| with torch.no_grad(): | |
| for batch in data_loader: | |
| batch["input_ids"] = batch["input_ids"].to(device) | |
| batch["attention_mask"] = batch["attention_mask"].to(device) | |
| generated_tokens = model.generate( | |
| **batch, forced_bos_token_id=lang_code_to_idx | |
| ) | |
| tgt_text = tokenizer.batch_decode( | |
| generated_tokens.cpu(), skip_special_tokens=True | |
| ) | |
| print("\n".join(tgt_text), file=output_file) | |
| pbar.update(len(tgt_text)) | |
| print(f"Translation done.\n") | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="Run the translation experiments") | |
| parser.add_argument( | |
| "--sentences_path", | |
| type=str, | |
| required=True, | |
| help="Path to a txt file containing the sentences to translate. One sentence per line.", | |
| ) | |
| parser.add_argument( | |
| "--output_path", | |
| type=str, | |
| required=True, | |
| help="Path to a txt file where the translated sentences will be written.", | |
| ) | |
| parser.add_argument( | |
| "--source_lang", | |
| type=str, | |
| required=True, | |
| help="Source language id. See: https://huggingface.co/facebook/m2m100_1.2B", | |
| ) | |
| parser.add_argument( | |
| "--target_lang", | |
| type=str, | |
| required=True, | |
| help="Target language id. See: https://huggingface.co/facebook/m2m100_1.2B", | |
| ) | |
| parser.add_argument( | |
| "--batch_size", | |
| type=int, | |
| default=8, | |
| help="Batch size", | |
| ) | |
| parser.add_argument( | |
| "--model_name", | |
| type=str, | |
| default="facebook/m2m100_1.2B", | |
| help="Path to the model to use. See: https://huggingface.co/models", | |
| ) | |
| parser.add_argument( | |
| "--precision", | |
| type=int, | |
| default=32, | |
| choices=[16, 32, 64], | |
| help="Precision of the model. 16, 32 or 64.", | |
| ) | |
| parser.add_argument( | |
| "--tensorrt", | |
| action="store_true", | |
| help="Use TensorRT to compile the model.", | |
| ) | |
| args = parser.parse_args() | |
| main( | |
| sentences_path=args.sentences_path, | |
| output_path=args.output_path, | |
| source_lang=args.source_lang, | |
| target_lang=args.target_lang, | |
| batch_size=args.batch_size, | |
| model_name=args.model_name, | |
| precision=args.precision, | |
| tensorrt=args.tensorrt, | |
| ) | |