Spaces:
Runtime error
Runtime error
| """Demo gradio app for some text/query augmentation.""" | |
| from __future__ import annotations | |
| import functools | |
| from typing import Any | |
| from typing import Callable | |
| from typing import Mapping | |
| from typing import Sequence | |
| import attr | |
| import environ | |
| import fasttext # not working with python3.9 | |
| import gradio as gr | |
| from transformers.pipelines import pipeline | |
| from transformers.pipelines.base import Pipeline | |
| from transformers.pipelines.token_classification import AggregationStrategy | |
def compose(*functions) -> Callable:
    """Chain ``functions`` into one callable applied from left to right.

    ``compose(f, g, h)(x)`` evaluates ``h(g(f(x)))``: the first function
    receives the input and every later function receives the previous
    result. With no functions, the returned callable is the identity.

    Args:
        functions: single-argument callables, in application order.

    Returns:
        A single-argument callable running the whole chain.
    """
    def run(value):
        for step in functions:
            value = step(value)
        return value

    return run
def mapped(fn) -> Callable:
    """Build a decorator that lifts a per-item function to a whole iterable.

    ``mapped(map)(f)`` returns a callable such that calling it with an
    iterable gives ``map(f, iterable)``. Likewise, ``mapped(filter)(pred)``
    gives ``filter(pred, iterable)``. Any extra positional or keyword
    arguments are forwarded to ``fn`` after the decorated function.

    Args:
        fn: higher-order function taking the decorated function as its first
            argument (e.g. ``map`` or ``filter``).

    Returns:
        A decorator.
    """
    def decorate(func):
        def lifted(*args, **kwargs):
            return fn(func, *args, **kwargs)

        return lifted

    return decorate
@attr.frozen
class Prediction:
    """Immutable record of a single model prediction.

    This must be an attrs class: instances are created with keyword
    arguments and copied with ``attr.evolve``. Without the decorator the
    class would have no ``__init__`` accepting ``label``/``score``.
    """

    label: str  # label as reported by the model
    score: float  # score the model assigned to ``label``
@attr.frozen
class Models:
    """Immutable bundle of the loaded predictors used by ``predict``.

    Constructed with keyword arguments, one ``Predictor`` per task.
    """

    identification: Predictor  # language identification
    translation: Predictor  # translation of the query into English
    classification: Predictor  # zero-shot classification into categories
    ner: Predictor  # general-purpose named-entity recognition
    recipe: Predictor  # recipe-specific entity tagging
@attr.frozen
class Predictor:
    """Model wrapper that loads its model once and forwards calls to it.

    Attributes:
        load_fn: zero-argument callable returning the model. It is called
            exactly once, when the predictor is constructed.
        predict_fn: ``predict_fn(model, *args, **kwargs)`` computing one
            prediction. By default it calls the model with a single query.
        model: the object returned by ``load_fn``. Not an init argument.
    """

    load_fn: Callable
    predict_fn: Callable = attr.field(default=lambda model, query: model(query))
    model: Any = attr.field(init=False)

    def __attrs_post_init__(self):
        # The class is frozen, so plain assignment is forbidden; bypass the
        # guard once to store the freshly loaded model.
        object.__setattr__(self, "model", self.load_fn())

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """Run ``predict_fn`` on the loaded model with the given arguments."""
        return self.predict_fn(self.model, *args, **kwds)
@environ.config
class AppConfig:
    """Application settings loaded from environment variables.

    Built with ``AppConfig.from_environ()``, which requires the
    ``environ.config`` decorator on this class and on every nested group.
    Every setting has a default, so no variable is mandatory.
    NOTE(review): names use environ-config's default prefix (presumably
    ``APP_<GROUP>_<NAME>``); confirm before documenting deployment settings.
    """

    @environ.config
    class Identification:
        """Identification model configuration."""

        model = environ.var(default="./models/lid.176.ftz")  # fastText model path
        max_results = environ.var(default=3, converter=int)  # languages shown in the UI

    @environ.config
    class Translation:
        """Translation models configuration."""

        # NOTE(review): ``model`` is not read by ``main`` in this file.
        model = environ.var(default="t5-small")
        sources = environ.var(default="de,fr")  # comma-separated source language codes
        target = environ.var(default="en")  # target language code

    @environ.config
    class Classification:
        """Classification model configuration."""

        model = environ.var(default="typeform/distilbert-base-uncased-mnli")
        max_results = environ.var(default=5, converter=int)  # categories shown in the UI

    @environ.config
    class NER:
        """Named-entity recognition models configuration."""

        general = environ.var(
            default="asahi417/tner-xlm-roberta-large-uncased-wnut2017",
        )
        recipe = environ.var(default="adamlin/recipe-tag-model")

    identification: Identification = environ.group(Identification)
    translation: Translation = environ.group(Translation)
    classification: Classification = environ.group(Classification)
    ner: NER = environ.group(NER)
def predict(
    models: Models,
    query: str,
    categories: Sequence[str],
    supported_languages: tuple[str, ...] = ("fr", "de"),
) -> tuple[
    Mapping[str, float],
    Mapping[str, float],
    str,
    Sequence[tuple[str, str | None]],
    Sequence[tuple[str, str | None]],
]:
    """Analyse a textual query.

    Steps:
    - identify the query language, keeping only ``supported_languages``
      and English in the reported scores;
    - translate the query to English when the most likely language is
      supported;
    - classify the English text against ``categories``;
    - extract general entities from the original query and recipe entities
      from the English text.

    Args:
        models: loaded predictors.
        query: user query.
        categories: candidate labels for the classifier.
        supported_languages: language codes that ``models.translation`` can
            translate. Any iterable of codes is accepted.

    Returns:
        ``(languages, classifications, translation, general_entities,
        recipe_entities)``. ``languages`` and ``classifications`` map a
        label to its score. Each entity list pairs a word with its tag, or
        is ``[(text, None)]`` when nothing was recognised.
    """
    allowed_languages = {*supported_languages, "en"}

    def predict_lang(query) -> Mapping[str, float]:
        def predict_fn(query) -> Sequence[Prediction]:
            # fastText predicts one line at a time and raises ValueError on
            # "\n", so flatten multi-line input first.
            single_line = query.replace("\n", " ")
            return tuple(
                Prediction(label=label, score=score)
                for label, score in zip(*models.identification(single_line, k=176))
            )

        def format_label(prediction: Prediction) -> Prediction:
            # Drop the "__label__" prefix so labels become bare language codes.
            return attr.evolve(
                prediction,
                label=prediction.label.replace("__label__", ""),
            )

        def filter_labels(prediction: Prediction) -> bool:
            return prediction.label in allowed_languages

        def format_output(predictions: Sequence[Prediction]) -> dict:
            return {pred.label: pred.score for pred in predictions}

        apply_fn = compose(
            predict_fn,
            # format_label handles one Prediction; map it over the tuple.
            functools.partial(map, format_label),
            functools.partial(filter, filter_labels),
            format_output,
        )
        return apply_fn(query)

    def translate_query(query: str, languages: Mapping[str, float]) -> str:
        if not languages:
            # No language identified: there is nothing to translate from.
            return query
        lang = max(languages, key=languages.__getitem__)
        if lang in supported_languages:
            return models.translation(query, lang)[0]["translation_text"]
        return query

    def classify_query(query, categories) -> Mapping[str, float]:
        if not categories:
            # The zero-shot pipeline rejects an empty label list.
            return {}
        predictions = models.classification(query, categories)
        return dict(zip(predictions["labels"], predictions["scores"]))

    def extract_entities(
        predict_fn: Callable,
        query: str,
    ) -> Sequence[tuple[str, str | None]]:
        predictions = predict_fn(query)
        if not predictions:
            # Nothing recognised: show the whole text untagged.
            return [(query, None)]
        # Aggregated NER output uses "entity_group"; raw output uses "entity".
        return [
            (pred["word"], pred.get("entity_group", pred.get("entity", None)))
            for pred in predictions
        ]

    languages = predict_lang(query)
    translation = translate_query(query, languages)
    classifications = classify_query(translation, categories)
    general_entities = extract_entities(models.ner, query)
    recipe_entities = extract_entities(models.recipe, translation)
    return languages, classifications, translation, general_entities, recipe_entities
def main():
    """Load every model from the environment configuration and launch the UI.

    Reads ``AppConfig`` from the environment and builds all predictors up
    front (each ``Predictor`` loads its model when constructed). It then
    serves ``predict`` through a Gradio interface in debug mode.
    """
    cfg: AppConfig = AppConfig.from_environ()

    def load_translation_models(
        sources: Sequence[str],
        target: str,
        models: Sequence[str],
    ) -> dict[str, Pipeline]:
        """Load one translation pipeline per source language.

        Args:
            sources: source language codes.
            target: target language code.
            models: checkpoint names; ``models[i]`` serves ``sources[i]``.

        Returns:
            Mapping from source language code to its translation pipeline.
        """
        return {
            src: pipeline(f"translation_{src}_to_{target}", model_name)
            for src, model_name in zip(sources, models)
        }

    def extract_commas_separated_values(value: str) -> Sequence[str]:
        """Split a comma-separated string, dropping empty items."""
        return tuple(filter(None, value.split(",")))

    sources = extract_commas_separated_values(cfg.translation.sources)
    target = cfg.translation.target
    default_categories = "cooking and recipe,traveling,location,information,buy or sell"

    models = Models(
        identification=Predictor(
            load_fn=lambda: fasttext.load_model(cfg.identification.model),
            predict_fn=lambda model, query, k: model.predict(query, k=k),
        ),
        translation=Predictor(
            load_fn=functools.partial(
                load_translation_models,
                sources=sources,
                target=target,
                # Name each checkpoint after its own language pair, so a
                # pipeline is never paired with another language's model
                # whatever the order or content of the configured sources.
                models=[f"Helsinki-NLP/opus-mt-{src}-{target}" for src in sources],
            ),
            predict_fn=lambda models, query, src: models[src](query),
        ),
        classification=Predictor(
            load_fn=lambda: pipeline(
                "zero-shot-classification",
                model=cfg.classification.model,
            ),
            predict_fn=lambda model, query, categories: model(query, categories),
        ),
        ner=Predictor(
            load_fn=lambda: pipeline(
                "ner",
                model=cfg.ner.general,
                aggregation_strategy=AggregationStrategy.SIMPLE,
            ),
        ),
        recipe=Predictor(
            load_fn=lambda: pipeline("ner", model=cfg.ner.recipe),
        ),
    )

    iface = gr.Interface(
        fn=lambda query, categories: predict(
            models,
            query.strip(),
            extract_commas_separated_values(categories),
            # Translate exactly the languages that have a loaded pipeline.
            supported_languages=sources,
        ),
        # One value per input component: (query, categories).
        examples=[
            ["gateau au chocolat paris", default_categories],
            ["Newyork LA flight", default_categories],
        ],
        # NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
        # namespaces (removed in Gradio 4); confirm the pinned Gradio version.
        inputs=[
            gr.inputs.Textbox(label="Query"),
            gr.inputs.Textbox(
                label="categories (commas separated and in english)",
                default=default_categories,
            ),
        ],
        outputs=[
            gr.outputs.Label(
                num_top_classes=cfg.identification.max_results,
                type="auto",
                label="Language identification",
            ),
            gr.outputs.Label(
                num_top_classes=cfg.classification.max_results,
                type="auto",
                label="Predicted categories",
            ),
            gr.outputs.Textbox(
                label="English query",
                type="auto",
            ),
            gr.outputs.HighlightedText(label="NER generic"),
            gr.outputs.HighlightedText(label="NER Recipes"),
        ],
        interpretation="default",
    )
    iface.launch(debug=True)


if __name__ == "__main__":
    main()