gcaillaut committed on
Commit
e1b65d6
·
1 Parent(s): a0f8e98

update app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -67
app.py CHANGED
@@ -1,9 +1,51 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
- import os
 
4
 
5
- LANGUAGES = ["en", "de", "es", "fr", "it", "nl", "sv", "pt"]
6
- DOMAINS = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  "Asset management": "am",
8
  "Annual report": "ar",
9
  "Corporate action": "corporateAction",
@@ -14,87 +56,185 @@ DOMAINS = {
14
  "Regulatory": "regulatory",
15
  "General": "general",
16
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- # Helper functions
19
  def language_token(lang):
20
  return f"<lang_{lang}>"
21
 
 
22
  def domain_token(dom):
23
  return f"<dom_{dom}>"
24
 
 
 
 
 
 
 
 
 
 
25
  def format_input(src, tgt_lang, src_lang, domain):
26
- assert tgt_lang in LANGUAGES
27
  tgt_lang_token = language_token(tgt_lang)
28
- # Prefix the input with <eos>
29
- base_input = f"<eos>{src}</src>{tgt_lang_token}"
30
- if src_lang:
31
- assert src_lang in LANGUAGES
 
 
 
32
  src_lang_token = language_token(src_lang)
33
  base_input = f"{base_input}{src_lang_token}"
34
- if domain:
35
- domain = DOMAINS.get(domain, "general")
 
 
36
  dom_token = domain_token(domain)
37
  base_input = f"{base_input}{dom_token}"
 
38
  return base_input
39
 
40
- # Initialize model and tokenizer globally to avoid reloading
41
- model_id = "LinguaCustodia/multilingual-multidomain-fin-mt-70M"
42
- tokenizer = AutoTokenizer.from_pretrained(model_id)
43
- model = AutoModelForCausalLM.from_pretrained(model_id)
44
-
45
- def translate(text, source_lang, target_lang, domain):
46
- if not text:
47
- return ""
48
-
49
- src_lang_code = language_map.get(source_lang)
50
- tgt_lang_code = language_map.get(target_lang)
51
-
52
- formatted_sentence = format_input(text, tgt_lang_code, src_lang_code, domain)
53
- inputs = tokenizer(formatted_sentence, return_tensors="pt", return_token_type_ids=False)
54
-
55
- outputs = model.generate(**inputs, max_new_tokens=256)
56
-
57
- input_size = inputs["input_ids"].size(1)
58
- translated_sentence = tokenizer.decode(
59
- outputs[0, input_size:], skip_special_tokens=True
 
 
 
 
 
60
  )
61
-
62
- return translated_sentence
63
-
64
- language_map = {
65
- "English": "en",
66
- "German": "de",
67
- "Spanish": "es",
68
- "French": "fr",
69
- "Italian": "it",
70
- "Dutch": "nl",
71
- "Swedish": "sv",
72
- "Portuguese": "pt"
73
- }
74
 
75
- title = "🌐 Multilingual Multidomain Financial Translator 🌐"
76
- description = """<p><center>Specialized Translation for Financial Documents across 8 Languages and 9 Domains</center></p>"""
77
- article = """<p style='text-align: center'>Model: <a href='https://huggingface.co/LinguaCustodia/multilingual-multidomain-fin-mt-70M' target='_blank'>LinguaCustodia/multilingual-multidomain-fin-mt-70M</a></p>"""
78
 
79
- examples = [
80
- ["Nous avons enregistré une croissance du chiffre d'affaires de 5,7% au troisième trimestre.", "French", "English", "Annual report"],
81
- ["The funds under management increased by €2.3 billion during the fiscal year.", "English", "Spanish", "Asset management"],
82
- ["Der Aufsichtsrat hat den Jahresabschluss geprüft und genehmigt.", "German", "French", "Regulatory"]
83
- ]
84
 
85
- demo = gr.Interface(
86
- fn=translate,
87
- title=title,
88
- description=description,
89
- article=article,
90
- inputs=[
91
- gr.Textbox(lines=5, placeholder="Enter text to translate (maximum 5 lines)", label="Input Text"),
92
- gr.Dropdown(choices=list(language_map.keys()), value="French", label="Source Language"),
93
- gr.Dropdown(choices=list(language_map.keys()), value="English", label="Target Language"),
94
- gr.Dropdown(choices=list(DOMAINS.keys()), value="General", label="Financial Domain"),
95
- ],
96
- outputs=gr.Textbox(label="Translation"),
97
- examples=examples
98
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- demo.launch()
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+ import itertools
5
 
6
# Use the GPU when torch can see one; otherwise run everything on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model size suffixes served by the demo (one result tab is built per entry).
MODEL_IDS = [
    "70M",
    # "160M",
]

# Size suffix -> Hugging Face Hub repository id.
MODEL_MAPPING = {
    size: f"LinguaCustodia/multilingual-multidomain-fin-mt-{size}"
    for size in MODEL_IDS
}

# Options for the single tokenizer shared by all model sizes. It is loaded
# from the 70M repository, with left-side padding and a 512-token limit.
_TOKENIZER_OPTIONS = {
    "pad_token": "<pad>",
    "mask_token": "<mask>",
    "eos_token": "<eos>",
    "padding_side": "left",
    "max_position_embeddings": 512,
    "model_max_length": 512,
}
TOKENIZER = AutoTokenizer.from_pretrained(
    MODEL_MAPPING["70M"],
    **_TOKENIZER_OPTIONS,
)

# Size suffix -> loaded causal LM, placed on DEVICE in bfloat16.
MODELS = {
    size: AutoModelForCausalLM.from_pretrained(
        repo_id,
        max_position_embeddings=512,
        device_map=DEVICE,
        torch_dtype=torch.bfloat16,
    )
    for size, repo_id in MODEL_MAPPING.items()
}
33
+
34
# Domain labels offered in the UI, in display order. Each label is looked up
# verbatim in DOMAIN_MAPPING when a translation is requested, so the spelling
# here must match that mapping's keys exactly.
DOMAINS = [
    "Auto",
    "Asset management",
    "Annual report",
    "Corporate action",
    "Equity research",
    "Fund fact sheet",
    "Kiid",
    "Life insurance",
    "Regulatory",
    "General",
]
46
+
47
+ DOMAIN_MAPPING = {
48
+ "Auto": None,
49
  "Asset management": "am",
50
  "Annual report": "ar",
51
  "Corporate action": "corporateAction",
 
56
  "Regulatory": "regulatory",
57
  "General": "general",
58
  }
59
# Inverse lookup: model domain code -> UI label (the None code maps to "Auto").
DOMAIN_MAPPING_REVERSED = {code: label for label, code in DOMAIN_MAPPING.items()}
60
+
61
# UI language name -> language code embedded in the "<lang_xx>" tokens.
LANG2CODE = dict(
    English="en",
    German="de",
    Spanish="es",
    French="fr",
    Italian="it",
    Dutch="nl",
    Swedish="sv",
    Portuguese="pt",
)
# Inverse lookup: language code -> UI language name.
CODE2LANG = {code: name for name, code in LANG2CODE.items()}
# Language names in alphabetical order, used to populate the dropdowns.
LANGUAGES = sorted(LANG2CODE)
73
+
74
 
 
75
def language_token(lang):
    """Return the language control token for *lang*, e.g. ``<lang_en>``."""
    return "<lang_{}>".format(lang)
77
 
78
+
79
def domain_token(dom):
    """Return the domain control token for *dom*, e.g. ``<dom_am>``."""
    return "<dom_{}>".format(dom)
81
 
82
+
83
def language_token_to_str(token):
    """Extract the language code from a ``<lang_xx>`` token.

    The prefix and the closing ``>`` are removed by position only; the token
    is not checked to actually have that shape.
    """
    prefix_len = len("<lang_")
    return token[prefix_len:-1]
85
+
86
+
87
def domain_token_to_str(token):
    """Extract the domain code from a ``<dom_xx>`` token.

    The prefix and the closing ``>`` are removed by position only; the token
    is not checked to actually have that shape.
    """
    prefix_len = len("<dom_")
    return token[prefix_len:-1]
89
+
90
+
91
def format_input(src, tgt_lang, src_lang, domain):
    """Build the prompt string fed to the translation model.

    The prompt is ``{eos}{src}</src>{target token}``, optionally followed by
    the source-language token and then the domain token. The optional tokens
    are positional: when ``src_lang`` is None the prompt stops after the
    target token and ``domain`` is ignored.

    Args:
        src: Text to translate.
        tgt_lang: Target language code (e.g. ``"fr"``).
        src_lang: Source language code, or None to leave it out of the prompt.
        domain: Domain code, or None to leave it out of the prompt.

    Returns:
        The formatted prompt string.
    """
    prompt = f"{TOKENIZER.eos_token}{src}</src>{language_token(tgt_lang)}"

    # Without a source-language token nothing else may follow.
    if src_lang is None:
        return prompt
    prompt += language_token(src_lang)

    if domain is not None:
        prompt += domain_token(domain)
    return prompt
110
 
111
+
112
def _control_token_at(sequence, position):
    """Return the token string at *position*, or None if the sequence is too short."""
    if position >= sequence.size(0):
        return None
    return TOKENIZER.convert_ids_to_tokens(sequence[position].item())


def _label_for_token(token, token_to_code, code_to_label):
    """Map a control token to its UI label, falling back to something displayable."""
    if token is None:
        return "Unknown"
    # An unexpected token is shown as-is instead of raising KeyError.
    return code_to_label.get(token_to_code(token), token)


def translate_with_model(model_name, text, tgt_lang, src_lang, domain):
    """Translate *text* with one model and report the source language and domain.

    The prompt always ends with the target-language token, optionally followed
    by the source-language and domain tokens (see ``format_input``). Control
    tokens that are not in the prompt are read back from the generated
    sequence, so this code relies on the output of ``generate`` starting with
    the prompt tokens.

    Args:
        model_name: Key into ``MODELS``.
        text: Text to translate.
        tgt_lang: Target language code.
        src_lang: Source language code, or None to read it from the output.
        domain: Domain code, or None to read it from the output.

    Returns:
        A dict with keys ``"model"``, ``"source_lang"``, ``"domain"`` and
        ``"translation"``. The language and domain are UI labels; if the token
        found at the expected position is not a known one, the raw token is
        returned, or ``"Unknown"`` if generation stopped before that position.
    """
    model = MODELS[model_name]
    prompt = format_input(text, tgt_lang, src_lang, domain)

    inputs = TOKENIZER(
        prompt, return_tensors="pt", return_token_type_ids=False
    ).to(DEVICE)
    prompt_len = inputs["input_ids"].size(1)

    # Index of the domain token in the full (prompt + generated) sequence.
    if src_lang is None:
        # Neither token is in the prompt: source language comes first, then domain.
        domain_pos = prompt_len + 1
    elif domain is None:
        # Source language is the last prompt token; the domain follows it.
        domain_pos = prompt_len
    else:
        # Both are in the prompt; the domain is its last token.
        domain_pos = prompt_len - 1
    src_lang_pos = domain_pos - 1

    outputs = model.generate(
        **inputs,
        num_beams=5,
        length_penalty=0.65,
        max_new_tokens=500,
        pad_token_id=TOKENIZER.pad_token_id,
        eos_token_id=TOKENIZER.eos_token_id,
    )
    sequence = outputs[0]

    # Everything after the domain token is the translation itself.
    translation = TOKENIZER.decode(
        sequence[domain_pos + 1 :], skip_special_tokens=True
    )
    src_lang_tok = _control_token_at(sequence, src_lang_pos)
    dom_tok = _control_token_at(sequence, domain_pos)

    return {
        "model": model_name,
        "source_lang": _label_for_token(
            src_lang_tok, language_token_to_str, CODE2LANG
        ),
        "domain": _label_for_token(
            dom_tok, domain_token_to_str, DOMAIN_MAPPING_REVERSED
        ),
        "translation": translation,
    }
153
+
154
+
155
def translate_with_all_models(text, tgt_lang, src_lang, domain):
    """Translate *text* with every model in ``MODEL_IDS`` for the UI callback.

    Args:
        text: Text to translate.
        tgt_lang: Target language name (a key of ``LANG2CODE``).
        src_lang: Source language name, or ``"Auto"`` to let the model pick it.
        domain: Domain label (a key of ``DOMAIN_MAPPING``).

    Returns:
        A flat list with three strings per model, in ``MODEL_IDS`` order:
        translation, source language label, domain label. Empty or
        whitespace-only input yields empty strings without running any model.

    Raises:
        KeyError: If ``tgt_lang`` or ``domain`` is not a known label.
    """
    fields = ("translation", "source_lang", "domain")

    # Nothing to translate: skip generation instead of prompting with no text.
    if not text or not text.strip():
        return [""] * (len(fields) * len(MODEL_IDS))

    tgt_code = LANG2CODE[tgt_lang]
    src_code = None if src_lang == "Auto" else LANG2CODE.get(src_lang)
    domain_code = DOMAIN_MAPPING[domain]

    flat_results = []
    for model_name in MODEL_IDS:
        result = translate_with_model(
            model_name, text, tgt_code, src_code, domain_code
        )
        flat_results.extend(result[field] for field in fields)
    return flat_results
170
+
171
+
172
# BibTeX entry shown at the bottom of the page. This must be a raw string:
# in a normal string literal the accent commands \" and \' are read as escape
# sequences and the backslashes disappear from the rendered citation.
CITATION_BIBTEX = r"""@inproceedings{DBLP:conf/wmt/CaillautNQLB24,
  author       = {Ga{\"{e}}tan Caillaut and
                  Mariam Nakhl{\'{e}} and
                  Raheel Qader and
                  Jingshu Liu and
                  Jean{-}Gabriel Barthelemy},
  title        = {Scaling Laws of Decoder-Only Models on the Multilingual Machine Translation Task},
  booktitle    = {{WMT}},
  pages        = {1318--1331},
  publisher    = {Association for Computational Linguistics},
  year         = {2024}
}"""

with gr.Blocks() as demo:
    # Page header.
    with gr.Row(variant="default"):
        title = "🌐 Multilingual Multidomain Financial Translator"
        description = """<p>Specialized Translation for Financial Documents across 8 Languages and 9 Domains</p>"""
        gr.HTML(f"<h1>{title}</h1>\n<p>{description}</p>")

    # Inputs: source text, language selectors and domain selector.
    with gr.Row(variant="panel"):
        with gr.Column(variant="default"):
            source_text = gr.Textbox(lines=3, label="Source sentence")
        with gr.Column(variant="default"):
            target_language = gr.Dropdown(
                LANGUAGES, value="French", label="Target language"
            )
            source_language = gr.Dropdown(
                LANGUAGES + ["Auto"], value="Auto", label="Source language"
            )
        with gr.Column(variant="default"):
            domain = gr.Radio(DOMAINS, value="Auto", label="Domain")

    with gr.Row():
        translate_btn = gr.Button("Translate", variant="primary")

    # One tab per model, each with its own three output boxes.
    with gr.Row(variant="panel"):
        outputs = {}
        for model_id in MODEL_IDS:
            with gr.Tab(model_id):
                outputs[model_id] = {
                    "translation": gr.Textbox(lines=2, label="Translation"),
                    "source_lang": gr.Textbox(
                        label="Predicted source language",
                        info='This is the predicted source language, if "Auto" is selected.',
                    ),
                    "domain": gr.Textbox(
                        label="Predicted domain",
                        info='This is the predicted domain, if "Auto" is checked.',
                    ),
                }
                gr.HTML(
                    f"<p>Model: <a href='https://huggingface.co/LinguaCustodia/multilingual-multidomain-fin-mt-{model_id}' target='_blank'>LinguaCustodia/multilingual-multidomain-fin-mt-{model_id}</a></p>"
                )

    # Citation footer.
    with gr.Row(variant="panel"):
        gr.HTML(
            "<p><strong>Please cite this work as:</strong>\n\n"
            f"<pre>{CITATION_BIBTEX}</pre></p>"
        )

    # The flattened component order must match the list returned by
    # translate_with_all_models: per model, translation / source_lang / domain.
    translate_btn.click(
        fn=translate_with_all_models,
        inputs=[source_text, target_language, source_language, domain],
        outputs=[
            outputs[model_id][field]
            for model_id in MODEL_IDS
            for field in ("translation", "source_lang", "domain")
        ],
    )

demo.launch()