gcaillaut committed
Commit f9273cb · 1 Parent(s): 5437ff2

implement streaming

Files changed (1): app.py (+24 -24)
app.py CHANGED
@@ -1,15 +1,16 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from transformers.cache_utils import DynamicCache
 import torch
 import itertools
+from threading import Thread
 
 DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 MODEL_IDS = [
     "70M",
     "160M",
     "410M",
-    "610M",
+    # "610M",
 ]
 MODEL_MAPPING = {
     model_id: f"LinguaCustodia/multilingual-multidomain-fin-mt-{model_id}"
@@ -211,31 +212,28 @@ def translate_with_model(model_name, text, tgt_lang, src_lang, domain):
     src_lang_token_pos = domain_token_pos - 1
     _tgt_lang_token_pos = src_lang_token_pos - 1
 
-    outputs = model.generate(
+    streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True)
+    generation_kwargs = dict(
         input_ids=inputs["input_ids"],
         attention_mask=inputs["attention_mask"],
         num_beams=1,
         max_new_tokens=500,
-        pad_token_id=TOKENIZER.pad_token_id,
-        eos_token_id=TOKENIZER.eos_token_id,
         past_key_values=past_key_values,
+        streamer=streamer,
     )
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
 
-    generated_translation = TOKENIZER.decode(
-        outputs[0, domain_token_pos + 1 :], skip_special_tokens=True
-    )
-
-    source_language_token = TOKENIZER.convert_ids_to_tokens(
-        outputs[0, src_lang_token_pos].item()
-    )
-    dom_token = TOKENIZER.convert_ids_to_tokens(outputs[0, domain_token_pos].item())
+    generated_translation = ""
+    for new_text in streamer:
+        generated_translation += new_text.replace(TOKENIZER.eos_token, "")
 
-    return {
-        "model": model_name,
-        "source_lang": CODE2LANG[language_token_to_str(source_language_token)],
-        "domain": DOMAIN_MAPPING_REVERSED[domain_token_to_str(dom_token)],
-        "translation": generated_translation,
-    }
+        yield {
+            "model": model_name,
+            "source_lang": CODE2LANG[src_lang],
+            "domain": DOMAIN_MAPPING_REVERSED[domain],
+            "translation": generated_translation,
+        }
 
 
 def translate_with_all_models(selected_models, text, tgt_lang, src_lang, domain):
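This is the core of the commit: instead of waiting for model.generate() to return the full sequence, the function runs generation on a background thread and yields a growing translation as TextIteratorStreamer emits decoded chunks. A minimal, self-contained sketch of that pattern, using gpt2 as a stand-in for the LinguaCustodia checkpoints:

```python
from threading import Thread

from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The translation is", return_tensors="pt")
# skip_prompt=True keeps the echoed prompt out of the streamed text.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

# generate() blocks until decoding finishes, so it runs on a worker
# thread while the main thread iterates the streamer as chunks arrive.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=30, streamer=streamer),
)
thread.start()

partial = ""
for chunk in streamer:
    # app.py additionally strips the EOS token from each chunk here.
    partial += chunk
    print(partial)  # grows one decoded chunk at a time
thread.join()
```

Because the streamer is not constructed with skip_special_tokens=True, app.py strips the EOS token manually via .replace(TOKENIZER.eos_token, ""); passing skip_special_tokens=True to TextIteratorStreamer (forwarded to tokenizer.decode) would be the usual alternative. Note also that with streaming the metadata switches from tokens read back out of the generated ids to the user-supplied src_lang and domain, since the output ids are no longer materialized.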
@@ -257,11 +255,13 @@ def translate_with_all_models(selected_models, text, tgt_lang, src_lang, domain)
 
     for model_id in selected_models:
         i = MODEL_INDEX[model_id]
-        model_output = translate_with_model(model_id, text, tgt_lang, src_lang, domain)
-        outputs[i * 3] = model_output["translation"]
-        outputs[i * 3 + 1] = model_output["source_lang"]
-        outputs[i * 3 + 2] = model_output["domain"]
-        yield outputs
+        for model_output in translate_with_model(
+            model_id, text, tgt_lang, src_lang, domain
+        ):
+            outputs[i * 3] = model_output["translation"]
+            outputs[i * 3 + 1] = model_output["source_lang"]
+            outputs[i * 3 + 2] = model_output["domain"]
+            yield outputs
 
 
 with gr.Blocks() as demo:
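The caller changes accordingly: translate_with_all_models now loops over the inner generator and re-yields the flat outputs list after every chunk, which is what lets Gradio update the UI incrementally, since Gradio treats a generator event handler as a streaming response. An illustrative, self-contained example of that mechanism (component names here are made up, not taken from app.py):

```python
import time

import gradio as gr

def stream_words(text):
    # A generator handler: Gradio re-renders the output component on
    # every yield, so the textbox fills in word by word.
    shown = ""
    for word in text.split():
        shown += word + " "
        time.sleep(0.1)  # stand-in for per-token generation latency
        yield shown

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Streamed output")
    inp.submit(stream_words, inputs=inp, outputs=out)

if __name__ == "__main__":
    demo.launch()
```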
 