implement streaming
app.py CHANGED
@@ -1,15 +1,16 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from transformers.cache_utils import DynamicCache
 import torch
 import itertools
+from threading import Thread
 
 DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 MODEL_IDS = [
     "70M",
     "160M",
     "410M",
-    "610M",
+    # "610M",
 ]
 MODEL_MAPPING = {
     model_id: f"LinguaCustodia/multilingual-multidomain-fin-mt-{model_id}"
@@ -211,31 +212,28 @@ def translate_with_model(model_name, text, tgt_lang, src_lang, domain):
     src_lang_token_pos = domain_token_pos - 1
     _tgt_lang_token_pos = src_lang_token_pos - 1
 
-
+    streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True)
+    generation_kwargs = dict(
         input_ids=inputs["input_ids"],
         attention_mask=inputs["attention_mask"],
         num_beams=1,
         max_new_tokens=500,
-        pad_token_id=TOKENIZER.pad_token_id,
-        eos_token_id=TOKENIZER.eos_token_id,
         past_key_values=past_key_values,
+        streamer=streamer,
     )
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
 
-    generated_translation =
-
-
-
-    source_language_token = TOKENIZER.convert_ids_to_tokens(
-        outputs[0, src_lang_token_pos].item()
-    )
-    dom_token = TOKENIZER.convert_ids_to_tokens(outputs[0, domain_token_pos].item())
+    generated_translation = ""
+    for new_text in streamer:
+        generated_translation += new_text.replace(TOKENIZER.eos_token, "")
 
-
-
-
-
-
-
+        yield {
+            "model": model_name,
+            "source_lang": CODE2LANG[src_lang],
+            "domain": DOMAIN_MAPPING_REVERSED[domain],
+            "translation": generated_translation,
+        }
 
 
 def translate_with_all_models(selected_models, text, tgt_lang, src_lang, domain):
@@ -257,11 +255,13 @@ def translate_with_all_models(selected_models, text, tgt_lang, src_lang, domain)
 
     for model_id in selected_models:
         i = MODEL_INDEX[model_id]
-        model_output
-
-
-
-
+        for model_output in translate_with_model(
+            model_id, text, tgt_lang, src_lang, domain
+        ):
+            outputs[i * 3] = model_output["translation"]
+            outputs[i * 3 + 1] = model_output["source_lang"]
+            outputs[i * 3 + 2] = model_output["domain"]
+            yield outputs
 
 
 with gr.Blocks() as demo:
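For context, the pattern this commit adopts is the standard `TextIteratorStreamer` idiom: `generate()` blocks until decoding finishes, so it is moved onto a background thread while the caller iterates over the streamer and re-yields partial text. Below is a minimal self-contained sketch of that idiom, assuming `gpt2` as a stand-in checkpoint (the app loads the LinguaCustodia fin-mt models instead) and a hypothetical `stream_generate` helper that is not part of the diff:

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # stand-in checkpoint, not the app's actual models
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)


def stream_generate(prompt, max_new_tokens=50):
    """Yield the cumulative generated text after every decoded chunk."""
    inputs = tokenizer(prompt, return_tensors="pt")
    # skip_prompt=True makes the streamer emit only newly generated tokens,
    # not the echoed prompt -- the same setting the diff uses.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        streamer=streamer,
    )
    # generate() blocks until completion, so it runs in a background
    # thread while this generator drains the streamer incrementally.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text
    thread.join()


for partial in stream_generate("Hello"):
    print(partial)
```

Yielding the cumulative string on each chunk is what lets Gradio repaint the output textboxes progressively: `translate_with_all_models` becomes a generator that re-yields the shared `outputs` list every time any selected model produces a new chunk. The diff strips `TOKENIZER.eos_token` from each chunk by hand; passing `skip_special_tokens=True` to `TextIteratorStreamer` (which forwards decode kwargs to `tokenizer.decode`) would be an alternative way to achieve the same effect.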