Jingshu's picture
Update app.py
3ba04e0 verified
raw
history blame
11.6 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from transformers.cache_utils import DynamicCache
import torch
import itertools
from threading import Thread
import spaces
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Short model identifiers, in the order they are shown in the UI.
MODEL_IDS = ["70M", "160M", "410M", "Bronze", "Silver", "Gold"]

# Short id -> Hugging Face Hub repository the checkpoint is loaded from.
MODEL_MAPPING = dict(
    (model_id, "LinguaCustodia/FinTranslate-" + model_id) for model_id in MODEL_IDS
)

# Short id -> position in MODEL_IDS (used to address each model's output slots).
MODEL_INDEX = dict(zip(MODEL_IDS, itertools.count()))
# One tokenizer, taken from the 70M checkpoint, is shared by every model below.
_TOKENIZER_KWARGS = dict(
    pad_token="<pad>",
    mask_token="<mask>",
    eos_token="<eos>",
    padding_side="left",
    max_position_embeddings=512,
    model_max_length=512,
)
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_MAPPING["70M"], **_TOKENIZER_KWARGS)

# Every checkpoint is loaded eagerly at import time, in bfloat16, onto DEVICE.
_MODEL_KWARGS = dict(
    max_position_embeddings=512,
    device_map=DEVICE,
    torch_dtype=torch.bfloat16,
)
MODELS = {
    name: AutoModelForCausalLM.from_pretrained(repo_id, **_MODEL_KWARGS)
    for name, repo_id in MODEL_MAPPING.items()
}
# UI label -> domain code used inside the <dom_...> prompt token.
# "Auto" maps to None: the model is then asked to predict the domain itself.
DOMAIN_MAPPING = {
    "Auto": None,
    "Asset management marketing": "am",
    "Annual report": "ar",
    "Corporate action": "corporateAction",
    "Equity research": "equi",
    "Fund fact sheet": "ffs",
    "Kiid": "kiid",
    "Life insurance": "lifeInsurance",
    "Regulatory": "regulatory",
    "General": "general",
}
# Radio-button choices, in the same order as DOMAIN_MAPPING.
DOMAINS = list(DOMAIN_MAPPING)
# Domain code (or None) -> UI label.
DOMAIN_MAPPING_REVERSED = dict(zip(DOMAIN_MAPPING.values(), DOMAIN_MAPPING.keys()))

# UI language name -> language code used inside the <lang_...> prompt token.
LANG2CODE = {
    "English": "en",
    "German": "de",
    "Spanish": "es",
    "French": "fr",
    "Italian": "it",
    "Dutch": "nl",
    "Swedish": "sv",
    "Portuguese": "pt",
}
CODE2LANG = {code: name for name, code in LANG2CODE.items()}
# Alphabetical list of language names for the dropdowns.
LANGUAGES = sorted(LANG2CODE)
def build_language_token(lang):
    """Return the special prompt token tagging language code ``lang`` (e.g. ``"<lang_en>"``)."""
    return "<lang_{}>".format(lang)
def build_domain_token(dom):
    """Return the special prompt token tagging domain code ``dom`` (e.g. ``"<dom_ffs>"``)."""
    return "<dom_{}>".format(dom)
def language_token_to_str(token):
    """Strip the ``<lang_`` prefix and trailing ``>`` from a language token.

    This is the inverse of ``build_language_token`` for well-formed tokens. The
    cut is purely positional, so a token without that wrapper is not rejected;
    it is simply truncated.
    """
    prefix_len = len("<lang_")
    return token[prefix_len:-1]
def domain_token_to_str(token):
    """Strip the ``<dom_`` prefix and trailing ``>`` from a domain token.

    This is the inverse of ``build_domain_token`` for well-formed tokens. The
    cut is purely positional, so a token without that wrapper is simply
    truncated.
    """
    prefix_len = len("<dom_")
    return token[prefix_len:-1]
def format_input(src, tgt_lang, src_lang, domain):
    """Build the model prompt for translating ``src`` into ``tgt_lang``.

    Layout: ``<eos>{src}</src><lang_{tgt}>[<lang_{src_lang}>[<dom_{domain}>]]``.

    The optional tags are appended strictly in order. When ``src_lang`` is None
    the prompt stops after the target-language tag, and ``domain`` is ignored
    even if it was given. ``domain`` is only appended when ``src_lang`` is
    present.
    """
    prompt = "{}{}</src>{}".format(
        TOKENIZER.eos_token, src, build_language_token(tgt_lang)
    )
    if src_lang is None:
        return prompt
    prompt += build_language_token(src_lang)
    if domain is not None:
        prompt += build_domain_token(domain)
    return prompt
def _encode_prompt(formatted_text):
    """Tokenize a formatted prompt and move every returned tensor to DEVICE."""
    encoded = TOKENIZER(
        formatted_text,
        return_attention_mask=True,
        return_tensors="pt",
        return_token_type_ids=False,
    )
    return {key: tensor.to(DEVICE) for key, tensor in encoded.items()}


def _predict_next_token(model, input_ids, attention_mask, past_key_values, cache_position):
    """Run one greedy forward step and return ``(token_id, token)``.

    ``past_key_values`` is passed to the model with ``use_cache=True``, and later
    steps reuse that same cache object. ``token_id`` has shape
    ``(1, batch)``, i.e. ``(1, 1)`` for the single prompt used here, so it can be
    fed straight back as the next ``input_ids``. ``token`` is its vocabulary
    string.
    """
    with torch.inference_mode():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=True,
            past_key_values=past_key_values,
            cache_position=cache_position,
        )
    token_id = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(0)
    token = TOKENIZER.convert_ids_to_tokens(token_id.squeeze().item())
    return token_id, token


@spaces.GPU(duration=120)
def translate_with_model(model_name, text, tgt_lang, src_lang, domain):
    """Stream the translation of ``text`` produced by one model.

    The prompt follows ``format_input``:
    ``<eos>{text}</src><lang_tgt><lang_src><dom_domain>``. When ``src_lang``
    and/or ``domain`` is None, the model fills in the missing tag itself with
    one greedy forward step per tag. The KV cache built during those steps is
    handed to ``model.generate``.

    Args:
        model_name: key of ``MODELS`` (one of ``MODEL_IDS``).
        text: source text to translate.
        tgt_lang: target language code (a value of ``LANG2CODE``).
        src_lang: source language code, or None to let the model predict it.
        domain: domain code (a value of ``DOMAIN_MAPPING``), or None to predict it.

    Yields:
        dict with keys ``"model"``, ``"source_lang"``, ``"domain"`` (display
        labels) and ``"translation"`` (the text generated so far).
    """
    model = MODELS[model_name]
    src_lang_provided = src_lang is not None
    domain_provided = domain is not None

    inputs = _encode_prompt(format_input(text, tgt_lang, src_lang, domain))
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    past_key_values = DynamicCache()
    cache_position = torch.arange(input_ids.size(1), dtype=torch.int64, device=DEVICE)

    if not src_lang_provided:
        # Without src_lang the prompt ends right after the target-language tag,
        # so the next token is the model's guess for the source-language tag.
        input_ids, token = _predict_next_token(
            model, input_ids, attention_mask, past_key_values, cache_position
        )
        src_lang = language_token_to_str(token)
        # From here on only the newly predicted token is fed; the prompt itself
        # is already in the cache.
        cache_position = cache_position[-1:] + 1
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))],
            dim=-1,
        )

    if not domain_provided:
        # The sequence now ends with the source-language tag, so the next token
        # is the model's guess for the domain tag.
        _, token = _predict_next_token(
            model, input_ids, attention_mask, past_key_values, cache_position
        )
        domain = domain_token_to_str(token)

    if not (src_lang_provided and domain_provided):
        # Rebuild the full prompt including the predicted tag(s). The cache
        # covers everything except the trailing tag(s); generate() is expected
        # to run only the positions not yet in past_key_values.
        inputs = _encode_prompt(format_input(text, tgt_lang, src_lang, domain))

    streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True)
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        num_beams=1,
        max_new_tokens=500,
        past_key_values=past_key_values,
        streamer=streamer,
        eos_token_id=TOKENIZER.eos_token_id,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Predicted tags come from an unconstrained argmax and may be outside the
    # known tables; show the raw code instead of raising KeyError mid-stream.
    source_lang_label = CODE2LANG.get(src_lang, src_lang)
    domain_label = DOMAIN_MAPPING_REVERSED.get(domain, domain)

    generated_translation = ""
    for new_text in streamer:
        generated_translation += new_text.replace(TOKENIZER.eos_token, "")
        yield {
            "model": model_name,
            "source_lang": source_lang_label,
            "domain": domain_label,
            "translation": generated_translation,
        }
    # The streamer is exhausted once generation ends; reap the worker thread.
    thread.join()
@spaces.GPU(duration=120)
def translate_with_all_models(selected_models, text, tgt_lang, src_lang, domain):
    """Translate ``text`` with every selected model, streaming UI updates.

    Args:
        selected_models: short ids (subset of ``MODEL_IDS``) ticked in the UI.
        text: source text.
        tgt_lang: target language display name (key of ``LANG2CODE``).
        src_lang: source language display name, or ``"Auto"`` to let each model
            predict it.
        domain: domain display name (key of ``DOMAIN_MAPPING``); ``"Auto"`` lets
            each model predict it.

    Yields:
        A flat list with three entries per model, in ``MODEL_IDS`` order:
        translation, predicted source language, predicted domain. This matches
        the output components wired to the Translate button.
    """
    tgt_lang = LANG2CODE[tgt_lang]
    src_lang = None if src_lang == "Auto" else LANG2CODE.get(src_lang)
    domain = DOMAIN_MAPPING[domain]

    outputs = list(
        itertools.chain.from_iterable(
            (
                ["Processing..."] * 3
                if model_id in selected_models
                else ["This model is disabled"] * 3
            )
            for model_id in MODEL_IDS
        )
    )
    # Publish the placeholder state right away. Otherwise nothing reaches the
    # UI until the first token arrives, and with no model selected the
    # generator would never yield, leaving stale results on screen.
    yield outputs

    for model_id in selected_models:
        offset = 3 * MODEL_INDEX[model_id]
        for model_output in translate_with_model(
            model_id, text, tgt_lang, src_lang, domain
        ):
            outputs[offset] = model_output["translation"]
            outputs[offset + 1] = model_output["source_lang"]
            outputs[offset + 2] = model_output["domain"]
            yield outputs
with gr.Blocks() as demo:
    with gr.Row(variant="default"):
        title = "🌐 Multilingual Multidomain Financial Translator"
        # Plain text: the gr.HTML call below supplies the <p> wrapper, so
        # wrapping it here too would produce nested <p> tags.
        description = "Specialized Translation for Financial Texts across 8 Languages and 9 Domains"
        gr.HTML(f"<h1>{title}</h1>\n<p>{description}</p>")

    with gr.Row(variant="panel"):
        with gr.Column(variant="default"):
            selected_models = gr.CheckboxGroup(
                choices=MODEL_IDS,
                value=MODEL_IDS,
                type="value",
                label="Models",
                container=True,
            )
            source_text = gr.Textbox(lines=3, label="Source sentence")
        with gr.Column(variant="default"):
            source_language = gr.Dropdown(
                LANGUAGES + ["Auto"], value="Auto", label="Source language"
            )
            target_language = gr.Dropdown(
                LANGUAGES, value="French", label="Target language"
            )
        with gr.Column(variant="default"):
            domain = gr.Radio(DOMAINS, value="Auto", label="Domain")

    with gr.Row():
        translate_btn = gr.Button("Translate", variant="primary")

    with gr.Row(variant="panel"):
        # One tab per model; each holds the three components that
        # translate_with_all_models fills (translation, source lang, domain).
        outputs = {}
        for model_id in MODEL_IDS:
            with gr.Tab(model_id):
                outputs[model_id] = {
                    "translation": gr.Textbox(
                        lines=2, label="Translation", container=True
                    ),
                    "source_lang": gr.Textbox(
                        label="Predicted source language",
                        info='This is the predicted source language, if "Auto" is selected.',
                        container=True,
                    ),
                    "domain": gr.Textbox(
                        label="Predicted domain",
                        info='This is the predicted domain, if "Auto" is checked.',
                        container=True,
                    ),
                }
                # Link to the repository the checkpoint is actually loaded from.
                repo_id = MODEL_MAPPING[model_id]
                gr.HTML(
                    f"<p>Model: <a href='https://huggingface.co/{repo_id}' target='_blank'>{repo_id}</a></p>"
                )

    with gr.Row(variant="panel"):
        # Raw string: the BibTeX accents ({\"{e}}, {\'{e}}) must keep their
        # backslashes. In a normal literal Python turns \" and \' into bare quotes.
        # <pre> is kept outside <p> because <p> cannot contain block elements.
        gr.HTML(
            r"""<p><strong>Please cite this work as:</strong></p>
<pre>@inproceedings{DBLP:conf/wmt/CaillautNQLB24,
author = {Ga{\"{e}}tan Caillaut and
Mariam Nakhl{\'{e}} and
Raheel Qader and
Jingshu Liu and
Jean{-}Gabriel Barthelemy},
title = {Scaling Laws of Decoder-Only Models on the Multilingual Machine Translation Task},
booktitle = {{WMT}},
pages = {1318--1331},
publisher = {Association for Computational Linguistics},
year = {2024}
}</pre>"""
        )

    # Output order must match translate_with_all_models: per model, in
    # MODEL_IDS order, (translation, source_lang, domain).
    translate_btn.click(
        fn=translate_with_all_models,
        inputs=[selected_models, source_text, target_language, source_language, domain],
        outputs=list(
            itertools.chain.from_iterable(
                [outputs[model_id][k] for k in ("translation", "source_lang", "domain")]
                for model_id in MODEL_IDS
            )
        ),
    )

demo.launch()