Spaces:
Paused
Paused
Models can be disabled
Browse files
app.py
CHANGED
|
@@ -14,6 +14,7 @@ MODEL_MAPPING = {
|
|
| 14 |
model_id: f"LinguaCustodia/multilingual-multidomain-fin-mt-{model_id}"
|
| 15 |
for model_id in MODEL_IDS
|
| 16 |
}
|
|
|
|
| 17 |
TOKENIZER = AutoTokenizer.from_pretrained(
|
| 18 |
MODEL_MAPPING["70M"],
|
| 19 |
pad_token="<pad>",
|
|
@@ -154,13 +155,25 @@ def translate_with_model(model_name, text, tgt_lang, src_lang, domain):
|
|
| 154 |
}
|
| 155 |
|
| 156 |
|
| 157 |
-
def translate_with_all_models(text, tgt_lang, src_lang, domain):
|
| 158 |
tgt_lang = LANG2CODE[tgt_lang]
|
| 159 |
src_lang = None if src_lang == "Auto" else LANG2CODE.get(src_lang)
|
| 160 |
domain = DOMAIN_MAPPING[domain]
|
| 161 |
|
| 162 |
outputs = [None] * (3 * len(MODEL_IDS))
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
model_output = translate_with_model(model_id, text, tgt_lang, src_lang, domain)
|
| 165 |
outputs[i * 3] = model_output["translation"]
|
| 166 |
outputs[i * 3 + 1] = model_output["source_lang"]
|
|
@@ -176,6 +189,13 @@ with gr.Blocks() as demo:
|
|
| 176 |
|
| 177 |
with gr.Row(variant="panel"):
|
| 178 |
with gr.Column(variant="default"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
source_text = gr.Textbox(lines=3, label="Source sentence")
|
| 180 |
with gr.Column(variant="default"):
|
| 181 |
source_language = gr.Dropdown(
|
|
@@ -195,14 +215,18 @@ with gr.Blocks() as demo:
|
|
| 195 |
for model_id in MODEL_IDS:
|
| 196 |
with gr.Tab(model_id):
|
| 197 |
outputs[model_id] = {
|
| 198 |
-
"translation": gr.Textbox(
|
|
|
|
|
|
|
| 199 |
"source_lang": gr.Textbox(
|
| 200 |
label="Predicted source language",
|
| 201 |
info='This is the predicted source language, if "Auto" is selected.',
|
|
|
|
| 202 |
),
|
| 203 |
"domain": gr.Textbox(
|
| 204 |
label="Predicted domain",
|
| 205 |
info='This is the predicted domain, if "Auto" is checked.',
|
|
|
|
| 206 |
),
|
| 207 |
}
|
| 208 |
gr.HTML(
|
|
@@ -227,7 +251,7 @@ with gr.Blocks() as demo:
|
|
| 227 |
|
| 228 |
translate_btn.click(
|
| 229 |
fn=translate_with_all_models,
|
| 230 |
-
inputs=[source_text, target_language, source_language, domain],
|
| 231 |
outputs=list(
|
| 232 |
itertools.chain.from_iterable(
|
| 233 |
[outputs[model_id][k] for k in ("translation", "source_lang", "domain")]
|
|
|
|
| 14 |
model_id: f"LinguaCustodia/multilingual-multidomain-fin-mt-{model_id}"
|
| 15 |
for model_id in MODEL_IDS
|
| 16 |
}
|
| 17 |
+
MODEL_INDEX = {m: i for i, m in enumerate(MODEL_IDS)}
|
| 18 |
TOKENIZER = AutoTokenizer.from_pretrained(
|
| 19 |
MODEL_MAPPING["70M"],
|
| 20 |
pad_token="<pad>",
|
|
|
|
| 155 |
}
|
| 156 |
|
| 157 |
|
| 158 |
+
def translate_with_all_models(selected_models, text, tgt_lang, src_lang, domain):
|
| 159 |
tgt_lang = LANG2CODE[tgt_lang]
|
| 160 |
src_lang = None if src_lang == "Auto" else LANG2CODE.get(src_lang)
|
| 161 |
domain = DOMAIN_MAPPING[domain]
|
| 162 |
|
| 163 |
outputs = [None] * (3 * len(MODEL_IDS))
|
| 164 |
+
outputs = list(
|
| 165 |
+
itertools.chain.from_iterable(
|
| 166 |
+
(
|
| 167 |
+
["Processing..."] * 3
|
| 168 |
+
if model_id in selected_models
|
| 169 |
+
else ["This model is disabled"] * 3
|
| 170 |
+
)
|
| 171 |
+
for model_id in MODEL_IDS
|
| 172 |
+
)
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
for model_id in selected_models:
|
| 176 |
+
i = MODEL_INDEX[model_id]
|
| 177 |
model_output = translate_with_model(model_id, text, tgt_lang, src_lang, domain)
|
| 178 |
outputs[i * 3] = model_output["translation"]
|
| 179 |
outputs[i * 3 + 1] = model_output["source_lang"]
|
|
|
|
| 189 |
|
| 190 |
with gr.Row(variant="panel"):
|
| 191 |
with gr.Column(variant="default"):
|
| 192 |
+
selected_models = gr.CheckboxGroup(
|
| 193 |
+
choices=MODEL_IDS,
|
| 194 |
+
value=MODEL_IDS,
|
| 195 |
+
type="value",
|
| 196 |
+
label="Models",
|
| 197 |
+
container=True,
|
| 198 |
+
)
|
| 199 |
source_text = gr.Textbox(lines=3, label="Source sentence")
|
| 200 |
with gr.Column(variant="default"):
|
| 201 |
source_language = gr.Dropdown(
|
|
|
|
| 215 |
for model_id in MODEL_IDS:
|
| 216 |
with gr.Tab(model_id):
|
| 217 |
outputs[model_id] = {
|
| 218 |
+
"translation": gr.Textbox(
|
| 219 |
+
lines=2, label="Translation", container=True
|
| 220 |
+
),
|
| 221 |
"source_lang": gr.Textbox(
|
| 222 |
label="Predicted source language",
|
| 223 |
info='This is the predicted source language, if "Auto" is selected.',
|
| 224 |
+
container=True,
|
| 225 |
),
|
| 226 |
"domain": gr.Textbox(
|
| 227 |
label="Predicted domain",
|
| 228 |
info='This is the predicted domain, if "Auto" is checked.',
|
| 229 |
+
container=True,
|
| 230 |
),
|
| 231 |
}
|
| 232 |
gr.HTML(
|
|
|
|
| 251 |
|
| 252 |
translate_btn.click(
|
| 253 |
fn=translate_with_all_models,
|
| 254 |
+
inputs=[selected_models, source_text, target_language, source_language, domain],
|
| 255 |
outputs=list(
|
| 256 |
itertools.chain.from_iterable(
|
| 257 |
[outputs[model_id][k] for k in ("translation", "source_lang", "domain")]
|