Spaces:
Running
on
Zero
Running
on
Zero
Optimize the preprocessing and generation (#11)
Browse files
- harmonize the language codes list with NLLB (d0a2f64cdae2fae119a127dba13609cb1d0b7542)
- raise errors when the source or target language is not chosen (5c565ab3ea2711194390b6c1b06a499b7da4534e)
- adjust the generation parameters to avoid repetitions (d0ffdbfb40076436f5f40e7deffb7440f5c35e07)
- add punctuation normalization and load the tokenizer only once (2a62da0ac954875090a26ab5dacfef37e9000aec)
- use sentence splitters from stopes (3740b63b75a6c13c1e25911113565bbb51a584a6)
Co-authored-by: David Dale <cointegrated@users.noreply.huggingface.co>
- app.py +33 -9
- flores.py +3 -3
- requirements.txt +3 -1
app.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
import spaces
|
| 2 |
import gradio as gr
|
|
|
|
|
|
|
| 3 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 4 |
from flores import code_mapping
|
| 5 |
import platform
|
|
@@ -28,28 +30,47 @@ def load_model():
|
|
| 28 |
model = load_model()
|
| 29 |
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
# cache function
|
| 39 |
@lru_cache(maxsize=100)
|
| 40 |
def translate(text: str, src_lang: str, tgt_lang: str):
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
# Only assign GPU if cache not used
|
| 44 |
@spaces.GPU
|
| 45 |
def _translate(text: str, src_lang: str, tgt_lang: str):
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
paragraphs = text.split("\n")
|
| 49 |
translated_paragraphs = []
|
| 50 |
|
| 51 |
for paragraph in paragraphs:
|
| 52 |
-
|
|
|
|
| 53 |
translated_sentences = []
|
| 54 |
|
| 55 |
for sentence in sentences:
|
|
@@ -62,9 +83,12 @@ def _translate(text: str, src_lang: str, tgt_lang: str):
|
|
| 62 |
)
|
| 63 |
translated_chunk = model.generate(
|
| 64 |
input_ids=torch.tensor([input_tokens]).to(device),
|
| 65 |
-
forced_bos_token_id=tokenizer.convert_tokens_to_ids(
|
| 66 |
max_length=len(input_tokens) + 50,
|
| 67 |
num_return_sequences=1,
|
|
|
|
|
|
|
|
|
|
| 68 |
)
|
| 69 |
translated_chunk = tokenizer.decode(
|
| 70 |
translated_chunk[0], skip_special_tokens=True
|
|
|
|
| 1 |
import spaces
|
| 2 |
import gradio as gr
|
| 3 |
+
from sacremoses import MosesPunctNormalizer
|
| 4 |
+
from stopes.pipelines.monolingual.utils.sentence_split import get_split_algo
|
| 5 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 6 |
from flores import code_mapping
|
| 7 |
import platform
|
|
|
|
| 30 |
model = load_model()
|
| 31 |
|
| 32 |
|
| 33 |
+
# Loading the tokenizer once, because re-loading it takes about 1.5 seconds each time
|
| 34 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
punct_normalizer = MosesPunctNormalizer(lang="en")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@lru_cache(maxsize=202)
|
| 41 |
+
def get_language_specific_sentence_splitter(language_code):
|
| 42 |
+
short_code = language_code[:3]
|
| 43 |
+
splitter = get_split_algo(short_code, "default")
|
| 44 |
+
return splitter
|
| 45 |
|
| 46 |
|
| 47 |
# cache function
|
| 48 |
@lru_cache(maxsize=100)
|
| 49 |
def translate(text: str, src_lang: str, tgt_lang: str):
|
| 50 |
+
if not src_lang:
|
| 51 |
+
raise gr.Error("The source language is empty! Please choose it in the dropdown list.")
|
| 52 |
+
if not tgt_lang:
|
| 53 |
+
raise gr.Error("The target language is empty! Please choose it in the dropdown list.")
|
| 54 |
+
return _translate(text, src_lang, tgt_lang)
|
| 55 |
+
|
| 56 |
|
| 57 |
# Only assign GPU if cache not used
|
| 58 |
@spaces.GPU
|
| 59 |
def _translate(text: str, src_lang: str, tgt_lang: str):
|
| 60 |
+
src_code = code_mapping[src_lang]
|
| 61 |
+
tgt_code = code_mapping[tgt_lang]
|
| 62 |
+
tokenizer.src_lang = src_code
|
| 63 |
+
tokenizer.tgt_lang = tgt_code
|
| 64 |
+
|
| 65 |
+
# normalizing the punctuation first
|
| 66 |
+
text = punct_normalizer.normalize(text)
|
| 67 |
|
| 68 |
paragraphs = text.split("\n")
|
| 69 |
translated_paragraphs = []
|
| 70 |
|
| 71 |
for paragraph in paragraphs:
|
| 72 |
+
splitter = get_language_specific_sentence_splitter(src_code)
|
| 73 |
+
sentences = list(splitter(paragraph))
|
| 74 |
translated_sentences = []
|
| 75 |
|
| 76 |
for sentence in sentences:
|
|
|
|
| 83 |
)
|
| 84 |
translated_chunk = model.generate(
|
| 85 |
input_ids=torch.tensor([input_tokens]).to(device),
|
| 86 |
+
forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
|
| 87 |
max_length=len(input_tokens) + 50,
|
| 88 |
num_return_sequences=1,
|
| 89 |
+
num_beams=5,
|
| 90 |
+
no_repeat_ngram_size=4, # repetition blocking works better if this number is below num_beams
|
| 91 |
+
renormalize_logits=True, # recompute token probabilities after banning the repetitions
|
| 92 |
)
|
| 93 |
translated_chunk = tokenizer.decode(
|
| 94 |
translated_chunk[0], skip_special_tokens=True
|
flores.py
CHANGED
|
@@ -10,7 +10,7 @@ code_mapping = {
|
|
| 10 |
"Amharic": "amh_Ethi",
|
| 11 |
"North Levantine Arabic": "apc_Arab",
|
| 12 |
"Modern Standard Arabic": "arb_Arab",
|
| 13 |
-
"Modern Standard Arabic (Romanized)": "arb_Latn",
|
| 14 |
"Najdi Arabic": "ars_Arab",
|
| 15 |
"Moroccan Arabic": "ary_Arab",
|
| 16 |
"Egyptian Arabic": "arz_Arab",
|
|
@@ -115,7 +115,7 @@ code_mapping = {
|
|
| 115 |
"Maithili": "mai_Deva",
|
| 116 |
"Malayalam": "mal_Mlym",
|
| 117 |
"Marathi": "mar_Deva",
|
| 118 |
-
"Minangkabau (Arabic script)": "min_Arab",
|
| 119 |
"Minangkabau (Latin script)": "min_Latn",
|
| 120 |
"Macedonian": "mkd_Cyrl",
|
| 121 |
"Plateau Malagasy": "plt_Latn",
|
|
@@ -149,7 +149,7 @@ code_mapping = {
|
|
| 149 |
"Russian": "rus_Cyrl",
|
| 150 |
"Sango": "sag_Latn",
|
| 151 |
"Sanskrit": "san_Deva",
|
| 152 |
-
"Santali": "sat_Olck",
|
| 153 |
"Sicilian": "scn_Latn",
|
| 154 |
"Shan": "shn_Mymr",
|
| 155 |
"Sinhala": "sin_Sinh",
|
|
|
|
| 10 |
"Amharic": "amh_Ethi",
|
| 11 |
"North Levantine Arabic": "apc_Arab",
|
| 12 |
"Modern Standard Arabic": "arb_Arab",
|
| 13 |
+
# "Modern Standard Arabic (Romanized)": "arb_Latn", # it is in FLORES, but not in NLLB
|
| 14 |
"Najdi Arabic": "ars_Arab",
|
| 15 |
"Moroccan Arabic": "ary_Arab",
|
| 16 |
"Egyptian Arabic": "arz_Arab",
|
|
|
|
| 115 |
"Maithili": "mai_Deva",
|
| 116 |
"Malayalam": "mal_Mlym",
|
| 117 |
"Marathi": "mar_Deva",
|
| 118 |
+
# "Minangkabau (Arabic script)": "min_Arab", # it is in FLORES, but not in NLLB
|
| 119 |
"Minangkabau (Latin script)": "min_Latn",
|
| 120 |
"Macedonian": "mkd_Cyrl",
|
| 121 |
"Plateau Malagasy": "plt_Latn",
|
|
|
|
| 149 |
"Russian": "rus_Cyrl",
|
| 150 |
"Sango": "sag_Latn",
|
| 151 |
"Sanskrit": "san_Deva",
|
| 152 |
+
"Santali": "sat_Beng",  # It is called sat_Olck in FLORES, but (less correctly) sat_Beng in NLLB
|
| 153 |
"Sicilian": "scn_Latn",
|
| 154 |
"Shan": "shn_Mymr",
|
| 155 |
"Sinhala": "sin_Sinh",
|
requirements.txt
CHANGED
|
@@ -3,4 +3,6 @@ transformers
|
|
| 3 |
torch
|
| 4 |
gradio==4.32.2
|
| 5 |
spaces
|
| 6 |
-
nltk
|
|
|
|
|
|
|
|
|
| 3 |
torch
|
| 4 |
gradio==4.32.2
|
| 5 |
spaces
|
| 6 |
+
nltk
|
| 7 |
+
sacremoses
|
| 8 |
+
stopes[mono] @ git+https://github.com/facebookresearch/stopes@better-sentence-splitters
|