import itertools
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.cache_utils import DynamicCache

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

MODEL_IDS = ["70M", "160M", "410M", "Bronze", "Silver", "Gold"]
MODEL_MAPPING = {
    model_id: f"LinguaCustodia/FinTranslate-{model_id}" for model_id in MODEL_IDS
}
MODEL_INDEX = {m: i for i, m in enumerate(MODEL_IDS)}

# All model sizes share the same tokenizer, so the 70M one is loaded once.
# The special-token strings below are assumptions (standard conventions);
# verify them against the model card.
TOKENIZER = AutoTokenizer.from_pretrained(
    MODEL_MAPPING["70M"],
    pad_token="<pad>",
    mask_token="<mask>",
    eos_token="</s>",
    padding_side="left",
    max_position_embeddings=512,
    model_max_length=512,
)

MODELS = {
    model_name: AutoModelForCausalLM.from_pretrained(
        model_id,
        max_position_embeddings=512,
        device_map=DEVICE,
        torch_dtype=torch.bfloat16,
    )
    for model_name, model_id in MODEL_MAPPING.items()
}

DOMAINS = [
    "Auto",
    "Asset management marketing",
    "Annual report",
    "Corporate action",
    "Equity research",
    "Fund fact sheet",
    "Kiid",
    "Life insurance",
    "Regulatory",
    "General",
]
DOMAIN_MAPPING = {
    "Auto": None,
    "Asset management marketing": "am",
    "Annual report": "ar",
    "Corporate action": "corporateAction",
    "Equity research": "equi",
    "Fund fact sheet": "ffs",
    "Kiid": "kiid",
    "Life insurance": "lifeInsurance",
    "Regulatory": "regulatory",
    "General": "general",
}
DOMAIN_MAPPING_REVERSED = {v: k for k, v in DOMAIN_MAPPING.items()}

LANG2CODE = {
    "English": "en",
    "German": "de",
    "Spanish": "es",
    "French": "fr",
    "Italian": "it",
    "Dutch": "nl",
    "Swedish": "sv",
    "Portuguese": "pt",
}
CODE2LANG = {v: k for k, v in LANG2CODE.items()}
LANGUAGES = sorted(LANG2CODE.keys())


def build_language_token(lang):
    return f"<lang:{lang}>"


def build_domain_token(dom):
    return f"<dom:{dom}>"


def language_token_to_str(token):
    return token[6:-1]  # "<lang:en>" -> "en"


def domain_token_to_str(token):
    return token[5:-1]  # "<dom:general>" -> "general"


def format_input(src, tgt_lang, src_lang, domain):
    # Prompt layout: <eos><source text><tgt lang>[<src lang>[<domain>]].
    # Source language and domain are optional; when omitted, the model
    # predicts them as the next tokens.
    tgt_lang_token = build_language_token(tgt_lang)
    prefix = TOKENIZER.eos_token
    base_input = f"{prefix}{src}{tgt_lang_token}"
    if src_lang is None:
        return base_input
    src_lang_token = build_language_token(src_lang)
    base_input = f"{base_input}{src_lang_token}"
    if domain is None:
        return base_input
    dom_token = build_domain_token(domain)
    return f"{base_input}{dom_token}"
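
# Illustration of the prompt layout (assuming the "<lang:...>"/"<dom:...>" token
# formats implied by the slicing in *_token_to_str above, and "</s>" as eos):
#   format_input("Net income rose.", "fr", "en", "ar")
#   -> "</s>Net income rose.<lang:fr><lang:en><dom:ar>"
# Target language comes first, then source language, then domain; this ordering
# is what lets translate_with_model greedily decode any missing tokens in turn.
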
@spaces.GPU(duration=120)
def translate_with_model(model_name, text, tgt_lang, src_lang, domain):
    model = MODELS[model_name]
    formatted_text = format_input(text, tgt_lang, src_lang, domain)
    inputs = TOKENIZER(
        formatted_text,
        return_attention_mask=True,
        return_tensors="pt",
        return_token_type_ids=False,
    )
    for k, v in inputs.items():
        inputs[k] = v.to(DEVICE)

    src_lang_provided = src_lang is not None
    domain_provided = domain is not None
    need_format_again = not (src_lang_provided and domain_provided)

    past_key_values = DynamicCache()
    cache_position = torch.arange(
        inputs["input_ids"].size(1), dtype=torch.int64, device=DEVICE
    )

    if not src_lang_provided:
        # Need to predict the source language: one forward pass over the prompt,
        # greedily picking the next token (a "<lang:...>" token).
        with torch.inference_mode():
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                use_cache=True,
                past_key_values=past_key_values,
                cache_position=cache_position,
            )
        src_lang_token_id = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(0)
        src_lang = language_token_to_str(
            TOKENIZER.convert_ids_to_tokens(src_lang_token_id.squeeze().item())
        )
        cache_position = cache_position[-1:] + 1
        attention_mask = inputs["attention_mask"]
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))],
            dim=-1,
        )
        # The prompt is now cached, so only the new token is fed next.
        inputs = {"input_ids": src_lang_token_id, "attention_mask": attention_mask}

    if not domain_provided:
        # Need to predict the domain: same greedy next-token trick, reusing the cache.
        with torch.inference_mode():
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                use_cache=True,
                past_key_values=past_key_values,
            )
        domain_token_id = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(0)
        domain = domain_token_to_str(
            TOKENIZER.convert_ids_to_tokens(domain_token_id.squeeze().item())
        )
        cache_position = cache_position[-1:] + 1
        attention_mask = inputs["attention_mask"]
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))],
            dim=-1,
        )
        inputs = {"input_ids": domain_token_id, "attention_mask": attention_mask}
    elif not src_lang_provided:
        # src_lang was not provided, but domain was. A forward pass is still
        # needed to build the kv cache for the domain token, e.g. "<dom:general>".
        dom_token = build_domain_token(domain)
        domain = domain_token_to_str(dom_token)
        domain_token_id = TOKENIZER.convert_tokens_to_ids(dom_token)
        inputs["input_ids"] = torch.hstack(
            [inputs["input_ids"], torch.tensor([[domain_token_id]], device=DEVICE)]
        )
        inputs["attention_mask"] = torch.hstack(
            [inputs["attention_mask"], inputs["attention_mask"].new_ones((1, 1))]
        )
        cache_position = torch.hstack([cache_position, cache_position[-1:] + 1])

    if need_format_again:
        # Rebuild the full prompt with the (now known) source language and domain;
        # past_key_values already covers the cached prefix, so generate() only
        # processes the new suffix.
        formatted_text = format_input(text, tgt_lang, src_lang, domain)
        inputs = TOKENIZER(
            formatted_text,
            return_attention_mask=True,
            return_tensors="pt",
            return_token_type_ids=False,
        )
        for k, v in inputs.items():
            inputs[k] = v.to(DEVICE)

    # Positions of the control tokens at the end of the prompt (kept for reference).
    domain_token_pos = inputs["input_ids"].size(1) - 1
    src_lang_token_pos = domain_token_pos - 1
    _tgt_lang_token_pos = src_lang_token_pos - 1

    streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True)
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        num_beams=1,
        max_new_tokens=500,
        past_key_values=past_key_values,
        streamer=streamer,
        eos_token_id=TOKENIZER.eos_token_id,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_translation = ""
    for new_text in streamer:
        generated_translation += new_text.replace(TOKENIZER.eos_token, "")
        yield {
            "model": model_name,
            "source_lang": CODE2LANG[src_lang],
            "domain": DOMAIN_MAPPING_REVERSED[domain],
            "translation": generated_translation,
        }


@spaces.GPU(duration=120)
def translate_with_all_models(selected_models, text, tgt_lang, src_lang, domain):
    # Map the display names coming from the UI to model-facing codes.
    tgt_lang = LANG2CODE[tgt_lang]
    src_lang = None if src_lang == "Auto" else LANG2CODE.get(src_lang)
    domain = DOMAIN_MAPPING[domain]
    outputs = list(
        itertools.chain.from_iterable(
            (
                ["Processing..."] * 3
                if model_id in selected_models
                else ["This model is disabled"] * 3
            )
            for model_id in MODEL_IDS
        )
    )
    for model_id in selected_models:
        i = MODEL_INDEX[model_id]
        for model_output in translate_with_model(
            model_id, text, tgt_lang, src_lang, domain
        ):
            outputs[i * 3] = model_output["translation"]
            outputs[i * 3 + 1] = model_output["source_lang"]
            outputs[i * 3 + 2] = model_output["domain"]
            yield outputs
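
# Illustrative use outside the UI (arguments are display names, as in the UI):
#   for rows in translate_with_all_models(["70M"], "Hello.", "French", "Auto", "Auto"):
#       print(rows[MODEL_INDEX["70M"] * 3])  # partial translation from the 70M model
# Each yielded list is flat, [translation, source_lang, domain] per model in
# MODEL_IDS order, matching the component list wired to translate_btn.click below.
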

with gr.Blocks() as demo:
    with gr.Row(variant="default"):
        title = "🌐 Multilingual Multidomain Financial Translator"
        description = (
            "Specialized Translation for Financial Texts"
            " across 8 Languages and 9 Domains"
        )
        # Plain centered headings; the exact markup here is an assumption.
        gr.HTML(
            f"<h1 style='text-align: center'>{title}</h1>\n"
            f"<h3 style='text-align: center'>{description}</h3>"
        )
    with gr.Row(variant="panel"):
        with gr.Column(variant="default"):
            selected_models = gr.CheckboxGroup(
                choices=MODEL_IDS,
                value=MODEL_IDS,
                type="value",
                label="Models",
                container=True,
            )
            source_text = gr.Textbox(lines=3, label="Source sentence")
        with gr.Column(variant="default"):
            source_language = gr.Dropdown(
                LANGUAGES + ["Auto"], value="Auto", label="Source language"
            )
            target_language = gr.Dropdown(
                LANGUAGES, value="French", label="Target language"
            )
        with gr.Column(variant="default"):
            domain = gr.Radio(DOMAINS, value="Auto", label="Domain")
    with gr.Row():
        translate_btn = gr.Button("Translate", variant="primary")
    with gr.Row(variant="panel"):
        outputs = {}
        for model_id in MODEL_IDS:
            with gr.Tab(model_id):
                outputs[model_id] = {
                    "translation": gr.Textbox(
                        lines=2, label="Translation", container=True
                    ),
                    "source_lang": gr.Textbox(
                        label="Predicted source language",
                        info='This is the predicted source language, if "Auto" is selected.',
                        container=True,
                    ),
                    "domain": gr.Textbox(
                        label="Predicted domain",
                        info='This is the predicted domain, if "Auto" is selected.',
                        container=True,
                    ),
                }
                gr.HTML(f"<p>Model: LinguaCustodia/FinTranslate-{model_id}</p>")
    with gr.Row(variant="panel"):
        gr.HTML(
            r"""<p>Please cite this work as:</p>
<pre>
@inproceedings{DBLP:conf/wmt/CaillautNQLB24,
  author       = {Ga{\"{e}}tan Caillaut and
                  Mariam Nakhl{\'{e}} and
                  Raheel Qader and
                  Jingshu Liu and
                  Jean{-}Gabriel Barthelemy},
  title        = {Scaling Laws of Decoder-Only Models on the Multilingual Machine Translation Task},
  booktitle    = {{WMT}},
  pages        = {1318--1331},
  publisher    = {Association for Computational Linguistics},
  year         = {2024}
}
</pre>"""
        )

    translate_btn.click(
        fn=translate_with_all_models,
        inputs=[
            selected_models,
            source_text,
            target_language,
            source_language,
            domain,
        ],
        outputs=list(
            itertools.chain.from_iterable(
                [outputs[model_id][k] for k in ("translation", "source_lang", "domain")]
                for model_id in MODEL_IDS
            )
        ),
    )

demo.launch()