# Hugging Face Spaces status: Paused
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import os | |
# Language codes the model understands; each one maps to a "<lang_xx>" control token.
LANGUAGES = ["en", "de", "es", "fr", "it", "nl", "sv", "pt"]

# Human-readable domain label -> domain code used in the "<dom_...>" control token.
DOMAINS = {
    "Asset management": "am",
    "Annual report": "ar",
    "Corporate action": "corporateAction",
    "Equity research": "equi",
    "Fund fact sheet": "ffs",
    "Kiid": "kiid",
    "Life insurance": "lifeInsurance",
    "Regulatory": "regulatory",
    "General": "general",
}


# Helper functions
def language_token(lang):
    """Return the control token for language code *lang*, e.g. ``"<lang_en>"``."""
    return f"<lang_{lang}>"


def domain_token(dom):
    """Return the control token for domain code *dom*, e.g. ``"<dom_ar>"``."""
    return f"<dom_{dom}>"


def format_input(src, tgt_lang, src_lang, domain):
    """Build the prompt string fed to the translation model.

    The layout is ``<eos>{src}</src><lang_{tgt}>[<lang_{src}>][<dom_{code}>]``.

    Args:
        src: Text to translate.
        tgt_lang: Target language code; must be one of ``LANGUAGES``.
        src_lang: Optional source language code. When truthy it must be one of
            ``LANGUAGES`` and its token is appended after the target token.
        domain: Optional domain label (a key of ``DOMAINS``). Unknown labels
            fall back to ``"general"``; a falsy value adds no domain token.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If ``tgt_lang``, or a truthy ``src_lang``, is not in ``LANGUAGES``.
    """
    # Explicit checks rather than ``assert``: asserts vanish under ``python -O``,
    # which would let unsupported codes (e.g. None) reach the model silently.
    if tgt_lang not in LANGUAGES:
        raise ValueError(f"Unsupported target language: {tgt_lang!r}")
    # Prefix the input with <eos>
    base_input = f"<eos>{src}</src>{language_token(tgt_lang)}"
    if src_lang:
        if src_lang not in LANGUAGES:
            raise ValueError(f"Unsupported source language: {src_lang!r}")
        base_input += language_token(src_lang)
    if domain:
        base_input += domain_token(DOMAINS.get(domain, "general"))
    return base_input
# Initialize model and tokenizer globally to avoid reloading on every request.
model_id = "LinguaCustodia/multilingual-multidomain-fin-mt-70M"
# Prefer the TOKEN environment variable; otherwise ``True`` asks transformers to
# use a locally stored Hugging Face token (from `huggingface-cli login`).
# NOTE(review): newer transformers releases deprecate `use_auth_token` in favour
# of `token`; kept here for compatibility with the library versions this app targets.
auth_token = os.environ.get("TOKEN") or True
# Pass the same credential to BOTH loaders: previously only the tokenizer received
# it, so a gated/private checkpoint would load its tokenizer and then fail on the model.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth_token)
model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=auth_token)
def translate(text, source_lang, target_lang, domain):
    """Translate *text* with the globally loaded model.

    Args:
        text: Input text from the UI. Empty, ``None`` or whitespace-only input
            returns ``""`` without calling the model.
        source_lang: Display name of the source language (a key of ``language_map``).
            An unknown name maps to ``None``, so no source-language token is added.
        target_lang: Display name of the target language (a key of ``language_map``).
        domain: Domain label forwarded to ``format_input``.

    Returns:
        The decoded translation (at most 256 new tokens), special tokens removed.
    """
    # Whitespace-only input has nothing to translate; skip the model call.
    if not text or not text.strip():
        return ""
    src_lang_code = language_map.get(source_lang)
    tgt_lang_code = language_map.get(target_lang)
    prompt = format_input(text, tgt_lang_code, src_lang_code, domain)
    inputs = tokenizer(prompt, return_tensors="pt", return_token_type_ids=False)
    outputs = model.generate(**inputs, max_new_tokens=256)
    # A causal LM's generate() output starts with the prompt tokens; decode only
    # the continuation that follows them.
    prompt_length = inputs["input_ids"].size(1)
    return tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=True)
# Language name shown in the UI dropdowns -> language code passed to the model.
# Insertion order defines the order of the dropdown choices.
language_map = dict(
    English="en",
    German="de",
    Spanish="es",
    French="fr",
    Italian="it",
    Dutch="nl",
    Swedish="sv",
    Portuguese="pt",
)
# Page chrome for the Gradio interface: heading, HTML blurb and HTML footer.
title = "🌐 Multilingual Multidomain Financial Translator 🌐"
description = (
    "<p><center>"
    "Specialized Translation for Financial Documents across 8 Languages and 9 Domains"
    "</center></p>"
)
_model_repo = "LinguaCustodia/multilingual-multidomain-fin-mt-70M"
article = (
    "<p style='text-align: center'>Model: "
    f"<a href='https://huggingface.co/{_model_repo}' target='_blank'>{_model_repo}</a>"
    "</p>"
)

# Clickable sample rows, in input order: [text, source language, target language, domain].
_example_rows = (
    ("Nous avons enregistré une croissance du chiffre d'affaires de 5,7% au troisième trimestre.", "French", "English", "Annual report"),
    ("The funds under management increased by €2.3 billion during the fiscal year.", "English", "Spanish", "Asset management"),
    ("Der Aufsichtsrat hat den Jahresabschluss geprüft und genehmigt.", "German", "French", "Regulatory"),
)
examples = [list(row) for row in _example_rows]
# Build the translation UI: one textbox plus three dropdowns feeding `translate`.
demo = gr.Interface(
    fn=translate,
    title=title,
    description=description,
    article=article,
    inputs=[
        # NOTE(review): `lines=5` only sets the textbox's initial height; it does not
        # enforce the "maximum 5 lines" the placeholder promises.
        gr.Textbox(lines=5, placeholder="Enter text to translate (maximum 5 lines)", label="Input Text"),
        gr.Dropdown(choices=list(language_map.keys()), value="French", label="Source Language"),
        gr.Dropdown(choices=list(language_map.keys()), value="English", label="Target Language"),
        gr.Dropdown(choices=list(DOMAINS.keys()), value="General", label="Financial Domain"),
    ],
    outputs=gr.Textbox(label="Translation"),
    examples=examples,
)
# `launch(enable_queue=True)` is deprecated in Gradio 3.x and removed in Gradio 4
# (TypeError at startup). Enabling the queue through `queue()` works in both;
# `queue()` returns the interface itself, so the calls chain.
demo.queue().launch()