gcaillaut commited on
Commit
91126af
·
1 Parent(s): af8c6cf

Models can be disabled

Browse files
Files changed (1) hide show
  1. app.py +28 -4
app.py CHANGED
@@ -14,6 +14,7 @@ MODEL_MAPPING = {
14
  model_id: f"LinguaCustodia/multilingual-multidomain-fin-mt-{model_id}"
15
  for model_id in MODEL_IDS
16
  }
 
17
  TOKENIZER = AutoTokenizer.from_pretrained(
18
  MODEL_MAPPING["70M"],
19
  pad_token="<pad>",
@@ -154,13 +155,25 @@ def translate_with_model(model_name, text, tgt_lang, src_lang, domain):
154
  }
155
 
156
 
157
- def translate_with_all_models(text, tgt_lang, src_lang, domain):
158
  tgt_lang = LANG2CODE[tgt_lang]
159
  src_lang = None if src_lang == "Auto" else LANG2CODE.get(src_lang)
160
  domain = DOMAIN_MAPPING[domain]
161
 
162
  outputs = [None] * (3 * len(MODEL_IDS))
163
- for i, model_id in enumerate(MODEL_IDS):
 
 
 
 
 
 
 
 
 
 
 
 
164
  model_output = translate_with_model(model_id, text, tgt_lang, src_lang, domain)
165
  outputs[i * 3] = model_output["translation"]
166
  outputs[i * 3 + 1] = model_output["source_lang"]
@@ -176,6 +189,13 @@ with gr.Blocks() as demo:
176
 
177
  with gr.Row(variant="panel"):
178
  with gr.Column(variant="default"):
 
 
 
 
 
 
 
179
  source_text = gr.Textbox(lines=3, label="Source sentence")
180
  with gr.Column(variant="default"):
181
  source_language = gr.Dropdown(
@@ -195,14 +215,18 @@ with gr.Blocks() as demo:
195
  for model_id in MODEL_IDS:
196
  with gr.Tab(model_id):
197
  outputs[model_id] = {
198
- "translation": gr.Textbox(lines=2, label="Translation"),
 
 
199
  "source_lang": gr.Textbox(
200
  label="Predicted source language",
201
  info='This is the predicted source language, if "Auto" is selected.',
 
202
  ),
203
  "domain": gr.Textbox(
204
  label="Predicted domain",
205
  info='This is the predicted domain, if "Auto" is checked.',
 
206
  ),
207
  }
208
  gr.HTML(
@@ -227,7 +251,7 @@ with gr.Blocks() as demo:
227
 
228
  translate_btn.click(
229
  fn=translate_with_all_models,
230
- inputs=[source_text, target_language, source_language, domain],
231
  outputs=list(
232
  itertools.chain.from_iterable(
233
  [outputs[model_id][k] for k in ("translation", "source_lang", "domain")]
 
14
  model_id: f"LinguaCustodia/multilingual-multidomain-fin-mt-{model_id}"
15
  for model_id in MODEL_IDS
16
  }
17
+ MODEL_INDEX = {m: i for i, m in enumerate(MODEL_IDS)}
18
  TOKENIZER = AutoTokenizer.from_pretrained(
19
  MODEL_MAPPING["70M"],
20
  pad_token="<pad>",
 
155
  }
156
 
157
 
158
+ def translate_with_all_models(selected_models, text, tgt_lang, src_lang, domain):
159
  tgt_lang = LANG2CODE[tgt_lang]
160
  src_lang = None if src_lang == "Auto" else LANG2CODE.get(src_lang)
161
  domain = DOMAIN_MAPPING[domain]
162
 
163
  outputs = [None] * (3 * len(MODEL_IDS))
164
+ outputs = list(
165
+ itertools.chain.from_iterable(
166
+ (
167
+ ["Processing..."] * 3
168
+ if model_id in selected_models
169
+ else ["This model is disabled"] * 3
170
+ )
171
+ for model_id in MODEL_IDS
172
+ )
173
+ )
174
+
175
+ for model_id in selected_models:
176
+ i = MODEL_INDEX[model_id]
177
  model_output = translate_with_model(model_id, text, tgt_lang, src_lang, domain)
178
  outputs[i * 3] = model_output["translation"]
179
  outputs[i * 3 + 1] = model_output["source_lang"]
 
189
 
190
  with gr.Row(variant="panel"):
191
  with gr.Column(variant="default"):
192
+ selected_models = gr.CheckboxGroup(
193
+ choices=MODEL_IDS,
194
+ value=MODEL_IDS,
195
+ type="value",
196
+ label="Models",
197
+ container=True,
198
+ )
199
  source_text = gr.Textbox(lines=3, label="Source sentence")
200
  with gr.Column(variant="default"):
201
  source_language = gr.Dropdown(
 
215
  for model_id in MODEL_IDS:
216
  with gr.Tab(model_id):
217
  outputs[model_id] = {
218
+ "translation": gr.Textbox(
219
+ lines=2, label="Translation", container=True
220
+ ),
221
  "source_lang": gr.Textbox(
222
  label="Predicted source language",
223
  info='This is the predicted source language, if "Auto" is selected.',
224
+ container=True,
225
  ),
226
  "domain": gr.Textbox(
227
  label="Predicted domain",
228
  info='This is the predicted domain, if "Auto" is checked.',
229
+ container=True,
230
  ),
231
  }
232
  gr.HTML(
 
251
 
252
  translate_btn.click(
253
  fn=translate_with_all_models,
254
+ inputs=[selected_models, source_text, target_language, source_language, domain],
255
  outputs=list(
256
  itertools.chain.from_iterable(
257
  [outputs[model_id][k] for k in ("translation", "source_lang", "domain")]