Angwolfrust committed on
Commit
9e5cadc
·
verified ·
1 Parent(s): 477bca3

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -0
  2. utils.py +92 -0
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ gradio>=4.0.0
+ torch
+ transformers
utils.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import functools
4
+
5
+ # Define a subset of popular languages mapped to FLORES-200 codes for better UX.
6
+ # NLLB supports 200+, but a dropdown of 200 items can be unwieldy.
7
+ # Codes reference: https://github.com/facebookresearch/flores/blob/main/flores200/README.md
8
+ LANGUAGE_CODES = {
9
+ "English": "eng_Latn",
10
+ "French": "fra_Latn",
11
+ "Spanish": "spa_Latn",
12
+ "German": "deu_Latn",
13
+ "Chinese (Simplified)": "zho_Hans",
14
+ "Chinese (Traditional)": "zho_Hant",
15
+ "Hindi": "hin_Deva",
16
+ "Arabic": "arb_Arab",
17
+ "Russian": "rus_Cyrl",
18
+ "Portuguese": "por_Latn",
19
+ "Japanese": "jpn_Jpan",
20
+ "Korean": "kor_Hang",
21
+ "Italian": "ita_Latn",
22
+ "Dutch": "nld_Latn",
23
+ "Turkish": "tur_Latn",
24
+ "Vietnamese": "vie_Latn",
25
+ "Indonesian": "ind_Latn",
26
+ "Persian": "pes_Arab",
27
+ "Polish": "pol_Latn",
28
+ "Ukrainian": "ukr_Cyrl",
29
+ "Swahili": "swh_Latn",
30
+ "Urdu": "urd_Arab",
31
+ "Bengali": "ben_Beng",
32
+ "Tamil": "tam_Taml"
33
+ }
34
+
35
+ MODEL_NAME = "facebook/nllb-200-distilled-600M"
36
+ _model = None
37
+ _tokenizer = None
38
+
39
def get_device() -> str:
    """Return the best available torch device string.

    Preference order: CUDA GPU, then Apple Silicon (MPS), then CPU.

    Returns:
        One of "cuda", "mps", or "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"
    # Older torch builds (and non-macOS wheels) may lack the mps backend
    # attribute entirely; guard the lookup so this never raises there.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
46
+
47
def load_model():
    """Lazily load and cache the NLLB model and tokenizer (singleton).

    The first call downloads/loads the checkpoint and moves the model to
    the best available device; subsequent calls return the cached objects.

    Returns:
        tuple: ``(model, tokenizer)`` ready for inference.
    """
    global _model, _tokenizer
    if _model is None:
        print(f"Loading {MODEL_NAME}...")
        device = get_device()
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
        # This module only ever runs inference: switch dropout/norm layers
        # to eval mode so outputs are deterministic.
        _model.eval()
        print("Model loaded successfully.")
    return _model, _tokenizer
59
+
60
def translate_text(text, src_lang_name, tgt_lang_name):
    """Translate ``text`` between two languages using NLLB.

    Args:
        text: Source text; empty/falsy input short-circuits to ``""``.
        src_lang_name: Human-readable source language (a LANGUAGE_CODES key).
        tgt_lang_name: Human-readable target language (a LANGUAGE_CODES key).

    Returns:
        The translated string, or an ``"Error during translation: ..."``
        message on failure (the UI shows errors as text instead of crashing).
    """
    if not text:
        return ""

    try:
        model, tokenizer = load_model()
        device = model.device

        # Map UI names to FLORES-200 codes; fall back to English -> French
        # if an unknown name somehow slips through the dropdown.
        src_code = LANGUAGE_CODES.get(src_lang_name, "eng_Latn")
        tgt_code = LANGUAGE_CODES.get(tgt_lang_name, "fra_Latn")

        # Tell the tokenizer which source-language prefix token to use.
        tokenizer.src_lang = src_code
        inputs = tokenizer(text, return_tensors="pt").to(device)

        # forced_bos_token_id makes generation start in the target language.
        # NOTE: `tokenizer.lang_code_to_id` was removed from the NLLB fast
        # tokenizer in recent transformers releases; convert_tokens_to_ids
        # resolves the language-code token id on all versions.
        tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_code)

        with torch.no_grad():  # inference only — skip autograd bookkeeping
            generated_tokens = model.generate(
                **inputs,
                forced_bos_token_id=tgt_lang_id,
                max_length=200,
            )

        # Single input sequence -> take the single decoded output.
        return tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]

    except Exception as e:
        # Deliberate best-effort boundary: surface the error in the UI
        # textbox rather than raising into Gradio.
        return f"Error during translation: {str(e)}"