Spaces:
Paused
Paused
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
| from transformers.cache_utils import DynamicCache | |
| import torch | |
| import itertools | |
| from threading import Thread | |
| import spaces | |
# Run on CUDA when torch reports it as available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Short model names; this order also drives the order of the UI tabs and of
# the flattened output list returned by the translation callback.
MODEL_IDS = ["70M", "160M", "410M", "Bronze", "Silver", "Gold"]

# Short model name -> Hugging Face Hub repository id.
MODEL_MAPPING = {name: "LinguaCustodia/FinTranslate-" + name for name in MODEL_IDS}

# Short model name -> its position in MODEL_IDS.
MODEL_INDEX = {name: position for position, name in enumerate(MODEL_IDS)}
# A single tokenizer, loaded from the 70M checkpoint, is used for every model
# in MODELS. The keyword arguments are forwarded unchanged to from_pretrained.
_TOKENIZER_OPTIONS = dict(
    pad_token="<pad>",
    mask_token="<mask>",
    eos_token="<eos>",
    padding_side="left",
    max_position_embeddings=512,
    model_max_length=512,
)
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_MAPPING["70M"], **_TOKENIZER_OPTIONS)

# Short model name -> loaded causal LM. Every checkpoint listed in
# MODEL_MAPPING is loaded eagerly at import time, in bfloat16, on DEVICE.
MODELS = {}
for _model_name, _repo_id in MODEL_MAPPING.items():
    MODELS[_model_name] = AutoModelForCausalLM.from_pretrained(
        _repo_id,
        max_position_embeddings=512,
        device_map=DEVICE,
        torch_dtype=torch.bfloat16,
    )
# UI label -> domain code used inside the "<dom_...>" token.
# "Auto" maps to None, which means "let the model predict the domain".
DOMAIN_MAPPING = {
    "Auto": None,
    "Asset management marketing": "am",
    "Annual report": "ar",
    "Corporate action": "corporateAction",
    "Equity research": "equi",
    "Fund fact sheet": "ffs",
    "Kiid": "kiid",
    "Life insurance": "lifeInsurance",
    "Regulatory": "regulatory",
    "General": "general",
}

# Domain labels in display order (the insertion order of DOMAIN_MAPPING).
DOMAINS = list(DOMAIN_MAPPING)

# Domain code -> UI label (None maps back to "Auto").
DOMAIN_MAPPING_REVERSED = {code: label for label, code in DOMAIN_MAPPING.items()}

# Language name -> language code used inside the "<lang_...>" token.
LANG2CODE = {
    "English": "en",
    "German": "de",
    "Spanish": "es",
    "French": "fr",
    "Italian": "it",
    "Dutch": "nl",
    "Swedish": "sv",
    "Portuguese": "pt",
}

# Language code -> language name.
CODE2LANG = {code: name for name, code in LANG2CODE.items()}

# Language names in alphabetical order, for the dropdowns.
LANGUAGES = sorted(LANG2CODE)
def build_language_token(lang):
    """Return the language tag token for ``lang``, e.g. ``"en"`` -> ``"<lang_en>"``."""
    return "<lang_{}>".format(lang)
def build_domain_token(dom):
    """Return the domain tag token for ``dom``, e.g. ``"ffs"`` -> ``"<dom_ffs>"``."""
    return "<dom_{}>".format(dom)
def language_token_to_str(token):
    """Extract the language code from a ``"<lang_xx>"`` token.

    The slice is purely positional: it drops as many leading characters as
    ``"<lang_"`` has and the final character, without checking that the
    token actually has that shape.
    """
    start = len("<lang_")
    return token[start:-1]
def domain_token_to_str(token):
    """Extract the domain code from a ``"<dom_xx>"`` token.

    The slice is purely positional: it drops as many leading characters as
    ``"<dom_"`` has and the final character, without checking that the
    token actually has that shape.
    """
    start = len("<dom_")
    return token[start:-1]
def format_input(src, tgt_lang, src_lang, domain):
    """Build the prompt string that is tokenized and fed to a model.

    The prompt is laid out as::

        <eos>{src}</src><lang_{tgt_lang}>[<lang_{src_lang}>[<dom_{domain}>]]

    Args:
        src: Source text to translate.
        tgt_lang: Target language code (always present in the prompt).
        src_lang: Source language code, or ``None`` to stop the prompt right
            after the target language token.
        domain: Domain code, or ``None`` to stop the prompt right after the
            source language token. Ignored when ``src_lang`` is ``None``.

    Returns:
        The formatted prompt string.
    """
    prompt = f"{TOKENIZER.eos_token}{src}</src>{build_language_token(tgt_lang)}"

    # The tags are positional: without a source language there is no slot in
    # which a domain tag could be placed, so the prompt ends here.
    if src_lang is None:
        return prompt
    prompt = f"{prompt}{build_language_token(src_lang)}"

    if domain is not None:
        prompt = f"{prompt}{build_domain_token(domain)}"
    return prompt
def translate_with_model(model_name, text, tgt_lang, src_lang, domain):
    """Translate ``text`` with a single model and stream the result.

    When ``src_lang`` and/or ``domain`` is ``None``, the missing tag is
    predicted greedily (argmax over the next-token logits) before generation
    starts. The prompt is then rebuilt with the resolved tags and the
    translation is generated in a background thread.

    Args:
        model_name: Key of the model to use in ``MODELS``.
        text: Source text to translate.
        tgt_lang: Target language code.
        src_lang: Source language code, or ``None`` to have it predicted.
        domain: Domain code, or ``None`` to have it predicted.

    Yields:
        A dict with the keys ``"model"``, ``"source_lang"``, ``"domain"`` and
        ``"translation"``. ``"translation"`` holds the text generated so far
        (cumulative), with the EOS token string removed.
    """
    model = MODELS[model_name]

    def encode(prompt):
        # Tokenize one prompt and move every returned tensor to DEVICE.
        encoded = TOKENIZER(
            prompt,
            return_attention_mask=True,
            return_tensors="pt",
            return_token_type_ids=False,
        )
        return {key: tensor.to(DEVICE) for key, tensor in encoded.items()}

    src_lang_provided = src_lang is not None
    domain_provided = domain is not None

    # format_input stops at the first missing tag, so the next token after
    # this prompt is exactly the tag that has to be predicted first.
    inputs = encode(format_input(text, tgt_lang, src_lang, domain))

    # Cache shared by the prediction passes and by generate().
    past_key_values = DynamicCache()

    if not src_lang_provided:
        # Predict the source language tag from the prompt ending in the
        # target language token.
        cache_position = torch.arange(
            inputs["input_ids"].size(1), dtype=torch.int64, device=DEVICE
        )
        with torch.inference_mode():
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                use_cache=True,
                past_key_values=past_key_values,
                cache_position=cache_position,
            )
        src_lang_token_id = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(0)
        src_lang = language_token_to_str(
            TOKENIZER.convert_ids_to_tokens(src_lang_token_id.squeeze().item())
        )
        # Next step input: only the predicted token, with the attention mask
        # extended by one position to cover it on top of the cached prompt.
        attention_mask = inputs["attention_mask"]
        inputs = {
            "input_ids": src_lang_token_id,
            "attention_mask": torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))],
                dim=-1,
            ),
        }

    if not domain_provided:
        # Predict the domain tag. `inputs` is either the full prompt (source
        # language given) or the single predicted source language token.
        with torch.inference_mode():
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                use_cache=True,
                past_key_values=past_key_values,
            )
        domain_token_id = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(0)
        domain = domain_token_to_str(
            TOKENIZER.convert_ids_to_tokens(domain_token_id.squeeze().item())
        )

    if not (src_lang_provided and domain_provided):
        # At least one tag was predicted: rebuild the complete prompt with the
        # resolved tags. This replaces whatever `inputs` held before, which is
        # why no other bookkeeping is kept from the prediction passes.
        inputs = encode(format_input(text, tgt_lang, src_lang, domain))

    streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True)
    # NOTE(review): generate() receives the complete prompt together with the
    # cache filled by the prediction passes; presumably it only processes the
    # tokens not yet held in the cache - confirm against the installed
    # transformers version.
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        num_beams=1,
        max_new_tokens=500,
        past_key_values=past_key_values,
        streamer=streamer,
        eos_token_id=TOKENIZER.eos_token_id,
    )

    # A predicted token is not guaranteed to be a known language/domain tag;
    # fall back to the raw string instead of raising KeyError mid-stream.
    source_lang_label = CODE2LANG.get(src_lang, src_lang)
    domain_label = DOMAIN_MAPPING_REVERSED.get(domain, domain)

    # generate() blocks until completion, so it runs in a worker thread while
    # this generator consumes the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_translation = ""
    for new_text in streamer:
        generated_translation += new_text.replace(TOKENIZER.eos_token, "")
        yield {
            "model": model_name,
            "source_lang": source_lang_label,
            "domain": domain_label,
            "translation": generated_translation,
        }

    # The stream is exhausted; wait for the worker thread to finish.
    thread.join()
def translate_with_all_models(selected_models, text, tgt_lang, src_lang, domain):
    """Run the selected models one after another and stream the UI state.

    Args:
        selected_models: Model names (keys of ``MODEL_INDEX``) to run.
        text: Source text to translate.
        tgt_lang: Target language name (a key of ``LANG2CODE``).
        src_lang: Source language name, or ``"Auto"`` to have it predicted.
        domain: Domain label (a key of ``DOMAIN_MAPPING``).

    Yields:
        A flat list with three entries per model in ``MODEL_IDS`` order:
        translation, source language, domain. The same list object is
        updated in place and yielded again after every streamed chunk.
    """
    tgt_lang = LANG2CODE[tgt_lang]
    src_lang = None if src_lang == "Auto" else LANG2CODE.get(src_lang)
    domain = DOMAIN_MAPPING[domain]

    # Placeholder state: one triple per model, depending on whether it runs.
    outputs = []
    for model_id in MODEL_IDS:
        placeholder = (
            "Processing..." if model_id in selected_models else "This model is disabled"
        )
        outputs.extend([placeholder] * 3)

    # Publish the placeholders right away so the UI updates before the first
    # token is produced, and also when no model is selected at all.
    yield outputs

    for model_id in selected_models:
        base = 3 * MODEL_INDEX[model_id]
        for model_output in translate_with_model(
            model_id, text, tgt_lang, src_lang, domain
        ):
            outputs[base] = model_output["translation"]
            outputs[base + 1] = model_output["source_lang"]
            outputs[base + 2] = model_output["domain"]
            yield outputs
# BibTeX entry shown at the bottom of the page. This is a raw string so that
# the LaTeX accent commands (\" and \') keep their backslashes.
_CITATION_BIBTEX = r"""@inproceedings{DBLP:conf/wmt/CaillautNQLB24,
  author       = {Ga{\"{e}}tan Caillaut and
                  Mariam Nakhl{\'{e}} and
                  Raheel Qader and
                  Jingshu Liu and
                  Jean{-}Gabriel Barthelemy},
  title        = {Scaling Laws of Decoder-Only Models on the Multilingual Machine Translation Task},
  booktitle    = {{WMT}},
  pages        = {1318--1331},
  publisher    = {Association for Computational Linguistics},
  year         = {2024}
}"""

with gr.Blocks() as demo:
    # Page header.
    with gr.Row(variant="default"):
        title = "🌐 Multilingual Multidomain Financial Translator"
        description = "Specialized Translation for Financial Texts across 8 Languages and 9 Domains"
        gr.HTML(f"<h1>{title}</h1>\n<p>{description}</p>")

    # Inputs: model selection and source text, languages, domain.
    with gr.Row(variant="panel"):
        with gr.Column(variant="default"):
            selected_models = gr.CheckboxGroup(
                choices=MODEL_IDS,
                value=MODEL_IDS,
                type="value",
                label="Models",
                container=True,
            )
            source_text = gr.Textbox(lines=3, label="Source sentence")
        with gr.Column(variant="default"):
            source_language = gr.Dropdown(
                LANGUAGES + ["Auto"], value="Auto", label="Source language"
            )
            target_language = gr.Dropdown(
                LANGUAGES, value="French", label="Target language"
            )
        with gr.Column(variant="default"):
            domain = gr.Radio(DOMAINS, value="Auto", label="Domain")

    with gr.Row():
        translate_btn = gr.Button("Translate", variant="primary")

    # One tab per model, each holding the three output fields for that model.
    with gr.Row(variant="panel"):
        outputs = {}
        for model_id in MODEL_IDS:
            with gr.Tab(model_id):
                outputs[model_id] = {
                    "translation": gr.Textbox(
                        lines=2, label="Translation", container=True
                    ),
                    "source_lang": gr.Textbox(
                        label="Predicted source language",
                        info='This is the predicted source language, if "Auto" is selected.',
                        container=True,
                    ),
                    "domain": gr.Textbox(
                        label="Predicted domain",
                        info='This is the predicted domain, if "Auto" is checked.',
                        container=True,
                    ),
                }
                # Link to the repository that is actually loaded for this tab.
                repo_id = MODEL_MAPPING[model_id]
                gr.HTML(
                    f"<p>Model: <a href='https://huggingface.co/{repo_id}' target='_blank'>{repo_id}</a></p>"
                )

    # Citation footer.
    with gr.Row(variant="panel"):
        gr.HTML(
            "<p><strong>Please cite this work as:</strong></p>\n"
            f"<pre>{_CITATION_BIBTEX}</pre>"
        )

    # The callback yields a flat list with three values per model in
    # MODEL_IDS order, so the output components are flattened the same way.
    output_components = [
        outputs[model_id][field]
        for model_id in MODEL_IDS
        for field in ("translation", "source_lang", "domain")
    ]
    translate_btn.click(
        fn=translate_with_all_models,
        inputs=[selected_models, source_text, target_language, source_language, domain],
        outputs=output_components,
    )

demo.launch()