gcaillaut's picture
update app.py
e1b65d6
raw
history blame
7.35 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import itertools
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
MODEL_IDS = [
"70M",
# "160M",
]
MODEL_MAPPING = {
model_id: f"LinguaCustodia/multilingual-multidomain-fin-mt-{model_id}"
for model_id in MODEL_IDS
}
TOKENIZER = AutoTokenizer.from_pretrained(
MODEL_MAPPING["70M"],
pad_token="<pad>",
mask_token="<mask>",
eos_token="<eos>",
padding_side="left",
max_position_embeddings=512,
model_max_length=512,
)
MODELS = {
model_name: AutoModelForCausalLM.from_pretrained(
model_id,
max_position_embeddings=512,
device_map=DEVICE,
torch_dtype=torch.bfloat16,
)
for model_name, model_id in MODEL_MAPPING.items()
}
DOMAINS = [
"Auto",
"Asset manangement",
"Annual report",
"Corporate action",
"Equity research",
"Fund fact sheet",
"Kiid",
"Life insurance",
"Regulatory",
"General",
]
DOMAIN_MAPPING = {
"Auto": None,
"Asset management": "am",
"Annual report": "ar",
"Corporate action": "corporateAction",
"Equity research": "equi",
"Fund fact sheet": "ffs",
"Kiid": "kiid",
"Life insurance": "lifeInsurance",
"Regulatory": "regulatory",
"General": "general",
}
DOMAIN_MAPPING_REVERSED = {v: k for k, v in DOMAIN_MAPPING.items()}
LANG2CODE = {
"English": "en",
"German": "de",
"Spanish": "es",
"French": "fr",
"Italian": "it",
"Dutch": "nl",
"Swedish": "sv",
"Portuguese": "pt",
}
CODE2LANG = {v: k for k, v in LANG2CODE.items()}
LANGUAGES = sorted(LANG2CODE.keys())
def language_token(lang):
return f"<lang_{lang}>"
def domain_token(dom):
return f"<dom_{dom}>"
def language_token_to_str(token):
return token[6:-1]
def domain_token_to_str(token):
return token[5:-1]
def format_input(src, tgt_lang, src_lang, domain):
tgt_lang_token = language_token(tgt_lang)
prefix = TOKENIZER.eos_token
base_input = f"{prefix}{src}</src>{tgt_lang_token}"
if src_lang is None:
return base_input
else:
src_lang_token = language_token(src_lang)
base_input = f"{base_input}{src_lang_token}"
if domain is None:
return base_input
else:
dom_token = domain_token(domain)
base_input = f"{base_input}{dom_token}"
return base_input
def translate_with_model(model_name, text, tgt_lang, src_lang, domain):
model = MODELS[model_name]
formatted_text = format_input(text, tgt_lang, src_lang, domain)
inputs = TOKENIZER(formatted_text, return_tensors="pt", return_token_type_ids=False)
for k, v in inputs.items():
inputs[k] = v.to(DEVICE)
if src_lang is None:
domain_token_pos = inputs["input_ids"].size(1) + 1
elif domain is None:
domain_token_pos = inputs["input_ids"].size(1)
else:
domain_token_pos = inputs["input_ids"].size(1) - 1
src_lang_token_pos = domain_token_pos - 1
_tgt_lang_token_pos = src_lang_token_pos - 1
outputs = model.generate(
**inputs,
num_beams=5,
length_penalty=0.65,
max_new_tokens=500,
pad_token_id=TOKENIZER.pad_token_id,
eos_token_id=TOKENIZER.eos_token_id,
)
generated_translation = TOKENIZER.decode(
outputs[0, domain_token_pos + 1 :], skip_special_tokens=True
)
source_language_token = TOKENIZER.convert_ids_to_tokens(
outputs[0, src_lang_token_pos].item()
)
domain_token = TOKENIZER.convert_ids_to_tokens(outputs[0, domain_token_pos].item())
return {
"model": model_name,
"source_lang": CODE2LANG[language_token_to_str(source_language_token)],
"domain": DOMAIN_MAPPING_REVERSED[domain_token_to_str(domain_token)],
"translation": generated_translation,
}
def translate_with_all_models(text, tgt_lang, src_lang, domain):
tgt_lang = LANG2CODE[tgt_lang]
src_lang = None if src_lang == "Auto" else LANG2CODE.get(src_lang)
domain = DOMAIN_MAPPING[domain]
res = {
model_name: translate_with_model(model_name, text, tgt_lang, src_lang, domain)
for model_name in MODEL_IDS
}
return list(
itertools.chain.from_iterable(
[res[model_id][k] for k in ("translation", "source_lang", "domain")]
for model_id in MODEL_IDS
)
)
with gr.Blocks() as demo:
with gr.Row(variant="default"):
title = "🌐 Multilingual Multidomain Financial Translator"
description = """<p>Specialized Translation for Financial Documents across 8 Languages and 9 Domains</p>"""
gr.HTML(f"<h1>{title}</h1>\n<p>{description}</p>")
with gr.Row(variant="panel"):
with gr.Column(variant="default"):
source_text = gr.Textbox(lines=3, label="Source sentence")
with gr.Column(variant="default"):
target_language = gr.Dropdown(
LANGUAGES, value="French", label="Target language"
)
source_language = gr.Dropdown(
LANGUAGES + ["Auto"], value="Auto", label="Source language"
)
with gr.Column(variant="default"):
domain = gr.Radio(DOMAINS, value="Auto", label="Domain")
with gr.Row():
translate_btn = gr.Button("Translate", variant="primary")
with gr.Row(variant="panel"):
outputs = {}
for model_id in MODEL_IDS:
with gr.Tab(model_id):
outputs[model_id] = {
"translation": gr.Textbox(lines=2, label="Translation"),
"source_lang": gr.Textbox(
label="Predicted source language",
info='This is the predicted source language, if "Auto" is selected.',
),
"domain": gr.Textbox(
label="Predicted domain",
info='This is the predicted domain, if "Auto" is checked.',
),
}
gr.HTML(
f"<p>Model: <a href='https://huggingface.co/LinguaCustodia/multilingual-multidomain-fin-mt-{model_id}' target='_blank'>LinguaCustodia/multilingual-multidomain-fin-mt-{model_id}</a></p>"
)
with gr.Row(variant="panel"):
gr.HTML(
"""<p><strong>Please cite this work as:</strong>\n\n<pre>@inproceedings{DBLP:conf/wmt/CaillautNQLB24,
author = {Ga{\"{e}}tan Caillaut and
Mariam Nakhl{\'{e}} and
Raheel Qader and
Jingshu Liu and
Jean{-}Gabriel Barthelemy},
title = {Scaling Laws of Decoder-Only Models on the Multilingual Machine Translation Task},
booktitle = {{WMT}},
pages = {1318--1331},
publisher = {Association for Computational Linguistics},
year = {2024}
}</pre></p>"""
)
translate_btn.click(
fn=translate_with_all_models,
inputs=[source_text, target_language, source_language, domain],
outputs=list(
itertools.chain.from_iterable(
[outputs[model_id][k] for k in ("translation", "source_lang", "domain")]
for model_id in MODEL_IDS
)
),
)
demo.launch()