"""Gradio demo: multilingual, multidomain financial machine translation.

Loads four sizes of the LinguaCustodia decoder-only MT models and serves a
Blocks UI that translates a sentence with every selected model, streaming
results tab-by-tab.  The models autodetect the source language and domain
when the user leaves those set to "Auto".

NOTE(review): this file was recovered from an HTML-stripped dump; every
angle-bracketed string (tokenizer special tokens, the <lang:..>/<dom:..>
control tokens, and the gr.HTML markup) was lost and has been reconstructed.
The control-token formats are pinned by the token[6:-1]/token[5:-1] slicers
below; the tokenizer special tokens and HTML markup are best-effort guesses —
confirm against the model repo's tokenizer_config.json.
"""
import itertools

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model sizes exposed in the UI; each maps to a Hugging Face repo id.
MODEL_IDS = [
    "70M",
    "160M",
    "410M",
    "610M",
]
MODEL_MAPPING = {
    model_id: f"LinguaCustodia/multilingual-multidomain-fin-mt-{model_id}"
    for model_id in MODEL_IDS
}
# Position of each model in the flat (translation, source_lang, domain)
# output list consumed by the Gradio click handler.
MODEL_INDEX = {m: i for i, m in enumerate(MODEL_IDS)}

# All model sizes share one tokenizer, so load it once from the smallest repo.
# NOTE(review): the special-token strings were stripped in extraction;
# "<pad>"/"<mask>"/"</s>" are reconstructed — TODO confirm.
TOKENIZER = AutoTokenizer.from_pretrained(
    MODEL_MAPPING["70M"],
    pad_token="<pad>",
    mask_token="<mask>",
    eos_token="</s>",
    padding_side="left",
    max_position_embeddings=512,
    model_max_length=512,
)

# Load every model up front so inference requests never pay load latency.
MODELS = {
    model_name: AutoModelForCausalLM.from_pretrained(
        model_id,
        max_position_embeddings=512,
        device_map=DEVICE,
        torch_dtype=torch.bfloat16,
    )
    for model_name, model_id in MODEL_MAPPING.items()
}

# Human-readable domain names shown in the UI, in display order.
DOMAINS = [
    "Auto",
    "Asset management",
    "Annual report",
    "Corporate action",
    "Equity research",
    "Fund fact sheet",
    "Kiid",
    "Life insurance",
    "Regulatory",
    "General",
]
# UI label -> model domain code ("Auto" -> None lets the model predict it).
DOMAIN_MAPPING = {
    "Auto": None,
    "Asset management": "am",
    "Annual report": "ar",
    "Corporate action": "corporateAction",
    "Equity research": "equi",
    "Fund fact sheet": "ffs",
    "Kiid": "kiid",
    "Life insurance": "lifeInsurance",
    "Regulatory": "regulatory",
    "General": "general",
}
DOMAIN_MAPPING_REVERSED = {v: k for k, v in DOMAIN_MAPPING.items()}

LANG2CODE = {
    "English": "en",
    "German": "de",
    "Spanish": "es",
    "French": "fr",
    "Italian": "it",
    "Dutch": "nl",
    "Swedish": "sv",
    "Portuguese": "pt",
}
CODE2LANG = {v: k for k, v in LANG2CODE.items()}
LANGUAGES = sorted(LANG2CODE.keys())


def language_token(lang):
    """Return the model's control token for language code *lang*.

    Format reconstructed from language_token_to_str, which strips the
    6-char prefix "<lang:" and the trailing ">".
    """
    return f"<lang:{lang}>"


def domain_token(dom):
    """Return the model's control token for domain code *dom*.

    Format reconstructed from domain_token_to_str, which strips the
    5-char prefix "<dom:" and the trailing ">".
    """
    return f"<dom:{dom}>"


def language_token_to_str(token):
    """Extract the language code from a "<lang:xx>" control token."""
    return token[6:-1]


def domain_token_to_str(token):
    """Extract the domain code from a "<dom:yy>" control token."""
    return token[5:-1]


def format_input(src, tgt_lang, src_lang, domain):
    """Build the model prompt: eos + source text + control tokens.

    The prompt always ends with the target-language token; the
    source-language and domain tokens are appended only when known —
    otherwise the model generates them itself (autodetection).
    """
    tgt_lang_token = language_token(tgt_lang)
    prefix = TOKENIZER.eos_token
    base_input = f"{prefix}{src}{tgt_lang_token}"
    if src_lang is None:
        return base_input
    src_lang_token = language_token(src_lang)
    base_input = f"{base_input}{src_lang_token}"
    if domain is None:
        return base_input
    dom_token = domain_token(domain)
    return f"{base_input}{dom_token}"


def translate_with_model(model_name, text, tgt_lang, src_lang, domain):
    """Translate *text* with one model and decode its predictions.

    Returns a dict with the model name, the (given or predicted) source
    language, the (given or predicted) domain, and the translation.
    """
    model = MODELS[model_name]
    formatted_text = format_input(text, tgt_lang, src_lang, domain)
    inputs = TOKENIZER(formatted_text, return_tensors="pt", return_token_type_ids=False)
    for k, v in inputs.items():
        inputs[k] = v.to(DEVICE)

    # Locate the domain control token in the *output* sequence.  The prompt
    # ends right after: tgt-lang [src-lang [domain]], so the domain token is
    # either inside the prompt or among the first generated tokens:
    #   both omitted  -> generated at len and len+1 -> domain at len+1
    #   only domain omitted -> generated at len
    #   both given    -> last prompt token, at len-1
    if src_lang is None:
        domain_token_pos = inputs["input_ids"].size(1) + 1
    elif domain is None:
        domain_token_pos = inputs["input_ids"].size(1)
    else:
        domain_token_pos = inputs["input_ids"].size(1) - 1
    src_lang_token_pos = domain_token_pos - 1
    _tgt_lang_token_pos = src_lang_token_pos - 1

    outputs = model.generate(
        **inputs,
        num_beams=5,
        length_penalty=0.65,
        max_new_tokens=500,
        pad_token_id=TOKENIZER.pad_token_id,
        eos_token_id=TOKENIZER.eos_token_id,
    )

    # The translation proper starts right after the domain token.
    generated_translation = TOKENIZER.decode(
        outputs[0, domain_token_pos + 1 :], skip_special_tokens=True
    )
    source_language_token = TOKENIZER.convert_ids_to_tokens(
        outputs[0, src_lang_token_pos].item()
    )
    # Renamed from "domain_token" to avoid shadowing the module-level helper.
    predicted_domain_token = TOKENIZER.convert_ids_to_tokens(
        outputs[0, domain_token_pos].item()
    )
    return {
        "model": model_name,
        "source_lang": CODE2LANG[language_token_to_str(source_language_token)],
        "domain": DOMAIN_MAPPING_REVERSED[domain_token_to_str(predicted_domain_token)],
        "translation": generated_translation,
    }


def translate_with_all_models(selected_models, text, tgt_lang, src_lang, domain):
    """Generator feeding the Gradio click handler.

    Yields the flat output list (3 fields per model, in MODEL_IDS order)
    after each selected model finishes, so results stream in tab-by-tab.
    Unselected models show a disabled message; not-yet-run selected models
    show a "Processing..." placeholder.
    """
    tgt_lang = LANG2CODE[tgt_lang]
    src_lang = None if src_lang == "Auto" else LANG2CODE.get(src_lang)
    domain = DOMAIN_MAPPING[domain]

    # (Removed dead pre-allocation: the list was rebuilt immediately below.)
    outputs = list(
        itertools.chain.from_iterable(
            (
                ["Processing..."] * 3
                if model_id in selected_models
                else ["This model is disabled"] * 3
            )
            for model_id in MODEL_IDS
        )
    )
    for model_id in selected_models:
        i = MODEL_INDEX[model_id]
        model_output = translate_with_model(model_id, text, tgt_lang, src_lang, domain)
        outputs[i * 3] = model_output["translation"]
        outputs[i * 3 + 1] = model_output["source_lang"]
        outputs[i * 3 + 2] = model_output["domain"]
        yield outputs


with gr.Blocks() as demo:
    with gr.Row(variant="default"):
        title = "🌐 Multilingual Multidomain Financial Translator"
        description = """Specialized Translation for Financial Documents across 8 Languages and 9 Domains"""
        # NOTE(review): original markup stripped during extraction;
        # reconstructed as a centered heading + subtitle.
        gr.HTML(
            f"<h1 align='center'>{title}</h1>\n"
            f"<h3 align='center'>{description}</h3>"
        )
    with gr.Row(variant="panel"):
        with gr.Column(variant="default"):
            selected_models = gr.CheckboxGroup(
                choices=MODEL_IDS,
                value=MODEL_IDS,
                type="value",
                label="Models",
                container=True,
            )
            source_text = gr.Textbox(lines=3, label="Source sentence")
        with gr.Column(variant="default"):
            source_language = gr.Dropdown(
                LANGUAGES + ["Auto"], value="Auto", label="Source language"
            )
            target_language = gr.Dropdown(
                LANGUAGES, value="French", label="Target language"
            )
        with gr.Column(variant="default"):
            domain = gr.Radio(DOMAINS, value="Auto", label="Domain")
    with gr.Row():
        translate_btn = gr.Button("Translate", variant="primary")
    with gr.Row(variant="panel"):
        # One tab per model; each holds the three output fields the
        # click handler fills (translation, predicted lang, predicted domain).
        outputs = {}
        for model_id in MODEL_IDS:
            with gr.Tab(model_id):
                outputs[model_id] = {
                    "translation": gr.Textbox(
                        lines=2, label="Translation", container=True
                    ),
                    "source_lang": gr.Textbox(
                        label="Predicted source language",
                        info='This is the predicted source language, if "Auto" is selected.',
                        container=True,
                    ),
                    "domain": gr.Textbox(
                        label="Predicted domain",
                        info='This is the predicted domain, if "Auto" is checked.',
                        container=True,
                    ),
                }
                # NOTE(review): markup reconstructed; original likely linked
                # to the model's Hugging Face page.
                gr.HTML(
                    f"<p>Model: LinguaCustodia/multilingual-multidomain-fin-mt-{model_id}</p>"
                )
    with gr.Row(variant="panel"):
        # NOTE(review): markup reconstructed around the recovered BibTeX.
        gr.HTML(
            """<p>Please cite this work as:</p>
<pre>
@inproceedings{DBLP:conf/wmt/CaillautNQLB24,
  author       = {Ga{\\"{e}}tan Caillaut and
                  Mariam Nakhl{\\'{e}} and
                  Raheel Qader and
                  Jingshu Liu and
                  Jean{-}Gabriel Barthelemy},
  title        = {Scaling Laws of Decoder-Only Models on the Multilingual Machine Translation Task},
  booktitle    = {{WMT}},
  pages        = {1318--1331},
  publisher    = {Association for Computational Linguistics},
  year         = {2024}
}
</pre>"""
        )

    translate_btn.click(
        fn=translate_with_all_models,
        inputs=[selected_models, source_text, target_language, source_language, domain],
        outputs=list(
            itertools.chain.from_iterable(
                [outputs[model_id][k] for k in ("translation", "source_lang", "domain")]
                for model_id in MODEL_IDS
            )
        ),
    )

if __name__ == "__main__":
    demo.launch()