Update app.py
app.py
CHANGED
@@ -17,7 +17,7 @@ from transformers import (
     pipeline as hf_pipeline
 )
 
-# ── 1) Model setup
+# ── 1) Model setup ──────────────────────────────────────────────────────────
 
 MODEL = "facebook/hf-seamless-m4t-medium"
 device = "cuda" if torch.cuda.is_available() else "cpu"
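The diff never shows how `processor` and `m4t_model` are constructed; for reference, a minimal sketch of the setup this hunk implies, assuming the text-to-text SeamlessM4T classes from transformers (the class names are an assumption, not taken from the commit):

import torch
from transformers import AutoProcessor, SeamlessM4TForTextToText

MODEL = "facebook/hf-seamless-m4t-medium"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed setup: the commit only shows MODEL and device; the exact
# processor/model classes are not visible in this diff.
processor = AutoProcessor.from_pretrained(MODEL)
m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL).to(device)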
@@ -31,7 +31,6 @@ if device == "cuda":
 m4t_model.eval()
 
 def translate_m4t(text: str, src_iso3: str, tgt_iso3: str, auto_detect=False) -> str:
-    """Single-string translation (used for initial auto-detect → English)."""
     src = None if auto_detect else src_iso3
     inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
     tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
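For orientation, SeamlessM4T expects ISO-639-3 language codes, so a call looks like this (inputs are illustrative, not from the commit):

# Hypothetical usage; "fra" and "eng" are ISO-639-3 codes.
english = translate_m4t("Bonjour tout le monde", "fra", "eng")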
@@ -40,7 +39,6 @@ def translate_m4t(text: str, src_iso3: str, tgt_iso3: str, auto_detect=False) ->
 def translate_m4t_batch(
     texts: List[str], src_iso3: str, tgt_iso3: str, auto_detect=False
 ) -> List[str]:
-    """Batch-mode translation: one generate() for many inputs."""
     src = None if auto_detect else src_iso3
     inputs = processor(
         text=texts, src_lang=src, return_tensors="pt", padding=True
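Padding the batch lets a single generate() call serve many strings at once; the intended call pattern is roughly (inputs illustrative):

# One batched call instead of a loop of single translations.
summaries_fr = translate_m4t_batch(
    ["A city in France.", "A famous museum."], "eng", "fra"
)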
@@ -53,9 +51,15 @@ def translate_m4t_batch(
     )
     return processor.batch_decode(tokens, skip_special_tokens=True)
 
-# ── 2) NER pipeline ─────────────────────────────────────────────────────────
 
-
+# ── 2) NER pipeline (updated for deprecation) ───────────────────────────────
+
+ner = hf_pipeline(
+    "ner",
+    model="dslim/bert-base-NER-uncased",
+    aggregation_strategy="simple"
+)
+
 
 # ── 3) CACHING helpers ──────────────────────────────────────────────────────
 
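The new block swaps the deprecated grouped_entities=True flag for aggregation_strategy="simple", which merges word-piece tokens back into whole entities. With that strategy the pipeline yields entity_group/word/score dicts, e.g. (output illustrative; the model is uncased, so words come back lowercased):

ner("Angela visited the Louvre in Paris")
# → [{"entity_group": "PER", "word": "angela", "score": 0.99, ...},
#    {"entity_group": "LOC", "word": "louvre", "score": 0.98, ...},
#    {"entity_group": "LOC", "word": "paris",  "score": 0.99, ...}]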
@@ -98,6 +102,7 @@ def wiki_summary_cache(name: str) -> str:
     except:
         return "No summary available."
 
+
 # ── 4) Per-entity worker ────────────────────────────────────────────────────
 
 def process_entity(ent) -> dict:
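Only the except branch of wiki_summary_cache is visible in this hunk; the pattern it implies is a cached summary lookup, roughly like this sketch (the wikipedia package and the sentence count are assumptions, not shown in the commit):

from functools import lru_cache
import wikipedia  # assumed dependency; not visible in this diff

@lru_cache(maxsize=256)
def wiki_summary_cache(name: str) -> str:
    try:
        return wikipedia.summary(name, sentences=2)
    except:
        return "No summary available."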
@@ -135,11 +140,11 @@ def process_entity(ent) -> dict:
 
 def get_context(
     text: str,
     source_lang: str,
     output_lang: str,
     auto_detect: bool
 ):
-    # a) Ensure
+    # a) Ensure English for NER
     if auto_detect or source_lang != "eng":
         en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
     else:
@@ -156,23 +161,22 @@ def get_context(
             seen.add(w)
             unique_ents.append(ent)
 
-    # c)
+    # c) Parallel I/O
     entities = []
     with ThreadPoolExecutor(max_workers=8) as exe:
         futures = [exe.submit(process_entity, ent) for ent in unique_ents]
         for fut in futures:
             entities.append(fut.result())
 
-    # d) Batch-translate
+    # d) Batch-translate non-English fields
     if output_lang != "eng":
         to_translate = []
         translations_info = []
 
         for i, e in enumerate(entities):
             if e["type"] == "wiki":
                 translations_info.append(("summary", i))
                 to_translate.append(e["summary"])
-
             elif e["type"] == "location":
                 for j, r in enumerate(e["restaurants"]):
                     translations_info.append(("restaurant", i, j))
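Collecting fut.result() in submission order keeps entities aligned with unique_ents; concurrent.futures.as_completed would yield results sooner but shuffle that order. A compact equivalent with the same ordering guarantee:

# Order-preserving collection, equivalent to the loop above:
entities = [f.result() for f in futures]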
@@ -181,10 +185,8 @@ def get_context(
                     translations_info.append(("attraction", i, j))
                     to_translate.append(a["name"])
 
-        # single batched call
         translated = translate_m4t_batch(to_translate, "eng", output_lang)
 
-        # redistribute
         for txt, info in zip(translated, translations_info):
             kind = info[0]
             if kind == "summary":
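translations_info records where each flattened string came from, so the single batched call can be scattered back afterwards. The scatter step itself is off-screen; a sketch of what the surrounding context implies (field names beyond those visible in the diff are assumptions):

# Scatter translated strings back to their source fields.
for txt, info in zip(translated, translations_info):
    if info[0] == "summary":
        entities[info[1]]["summary"] = txt
    elif info[0] == "restaurant":
        i, j = info[1], info[2]
        entities[i]["restaurants"][j]["name"] = txt
    elif info[0] == "attraction":
        i, j = info[1], info[2]
        entities[i]["attractions"][j]["name"] = txt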
@@ -200,7 +202,7 @@ def get_context(
     return {"entities": entities}
 
 
-# ── 6) Gradio interface
+# ── 6) Gradio interface ─────────────────────────────────────────────────────
 
 iface = gr.Interface(
     fn=get_context,
@@ -213,7 +215,7 @@ iface = gr.Interface(
|
|
| 213 |
outputs="json",
|
| 214 |
title="iVoice Context-Aware",
|
| 215 |
description="Returns only the detected entities and their related info."
|
| 216 |
-
).queue(
|
| 217 |
|
| 218 |
if __name__ == "__main__":
|
| 219 |
iface.launch(
|
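Gradio 4.x dropped several queue() kwargs; if the removed arguments included concurrency_count (a guess, since the diff truncates the old call), the closest modern equivalent would be:

# Assumption: Gradio 4.x API; concurrency_count was removed and
# default_concurrency_limit is its nearest replacement.
iface.queue(max_size=16, default_concurrency_limit=2)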