| """⭐ Text Classification with Optimum and ONNXRuntime | |
| Streamlit application to classify text using multiple models. | |
| Author: | |
| - @ChainYo - https://github.com/ChainYo | |
| """ | |
import plotly
import plotly.figure_factory as ff
import numpy as np
import pandas as pd
import streamlit as st

from pathlib import Path
from time import sleep
from typing import Dict, List, Union

from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer, ORTQuantizer
from optimum.onnxruntime.configuration import OptimizationConfig, AutoQuantizationConfig
from optimum.pipelines import pipeline as ort_pipeline
from transformers import BertTokenizer, BertForSequenceClassification, Pipeline
from transformers import pipeline as pt_pipeline

from utils import calculate_inference_time
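# Note: `calculate_inference_time` lives in the local `utils` module, which is
# not shown here. From its usage below, it is assumed to be a context manager
# that measures the wall-clock time of its block and appends it (in seconds)
# to the list passed as argument. A minimal sketch of what it might look like:
#
#     from contextlib import contextmanager
#     from time import perf_counter
#
#     @contextmanager
#     def calculate_inference_time(buffer):
#         start = perf_counter()
#         yield
#         buffer.append(perf_counter() - start)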
HUB_MODEL_PATH = "yiyanghkust/finbert-tone"
BASE_PATH = Path("models")
ONNX_MODEL_PATH = BASE_PATH.joinpath("model.onnx")
OPTIMIZED_BASE_PATH = BASE_PATH.joinpath("optimized")
OPTIMIZED_MODEL_PATH = OPTIMIZED_BASE_PATH.joinpath("model-optimized.onnx")
QUANTIZED_BASE_PATH = BASE_PATH.joinpath("quantized")
QUANTIZED_MODEL_PATH = QUANTIZED_BASE_PATH.joinpath("model-quantized.onnx")
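# On-disk layout of the exported models:
#   models/model.onnx                      vanilla ONNX export
#   models/optimized/model-optimized.onnx  graph-optimized variant
#   models/quantized/model-quantized.onnx  dynamically quantized variant
# These files act as a cache: each export below is skipped if its file exists.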
VAR2LABEL = {
    "pt_pipeline": "PyTorch",
    "ort_pipeline": "ONNXRuntime",
    "ort_optimized_pipeline": "ONNXRuntime (Optimized)",
    "ort_quantized_pipeline": "ONNXRuntime (Quantized)",
}
# Create the model directories if they do not exist yet
BASE_PATH.mkdir(exist_ok=True)
QUANTIZED_BASE_PATH.mkdir(exist_ok=True)
OPTIMIZED_BASE_PATH.mkdir(exist_ok=True)
def get_timers(
    samples: Union[List[str], str], exp_number: int, only_mean: bool = False
) -> Dict[str, Union[float, List[float]]]:
    """
    Calculate inference time for each model for a given sample or list of samples.

    Parameters
    ----------
    samples : Union[List[str], str]
        Sample or list of samples to calculate inference time for.
    exp_number : int
        Number of experiments to run per model.
    only_mean : bool, optional
        If True, keep only the mean inference time per model instead of the
        full list of measurements. Defaults to False.

    Returns
    -------
    Dict[str, Union[float, List[float]]]
        Dictionary of inference times for each model for the given samples.
    """
    if isinstance(samples, str):
        samples = [samples]
    timers: Dict[str, Union[float, List[float]]] = {}
    for model in VAR2LABEL.keys():
        time_buffer: List[float] = []
        st.session_state["pipeline"] = load_pipeline(model)
        for _ in range(exp_number):
            with calculate_inference_time(time_buffer):
                st.session_state["pipeline"](samples)
        timers[VAR2LABEL[model]] = np.mean(time_buffer) if only_mean else time_buffer
    return timers
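# Example (hypothetical values): with only_mean=False,
#   get_timers("profits are flat", exp_number=3)
# returns something like
#   {"PyTorch": [0.042, 0.039, 0.040], "ONNXRuntime": [0.021, 0.020, 0.020], ...}
# and with only_mean=True each list collapses to a single float.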
def get_plot(timers: Dict[str, List[float]]) -> plotly.graph_objs.Figure:
    """
    Plot the distribution of inference times for each model.

    Parameters
    ----------
    timers : Dict[str, List[float]]
        Dictionary of inference time measurements for each model.

    Returns
    -------
    plotly.graph_objs.Figure
        Distribution plot of the inference times.
    """
    data = pd.DataFrame.from_dict(timers, orient="columns")
    colors = ["#84353f", "#b4524b", "#f47e58", "#ffbe67"]
    fig = ff.create_distplot(
        [data[col] for col in data.columns], data.columns, bin_size=0.001, colors=colors, show_curve=False
    )
    fig.update_layout(title_text="Inference Time", xaxis_title="Inference Time (s)", yaxis_title="Number of Samples")
    return fig
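# `ff.create_distplot` draws one histogram per model from the raw measurement
# lists, so it needs several observations per model; the mean-only case is
# rendered as a table further down instead.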
def load_pipeline(pipeline_name: str) -> Pipeline:
    """
    Load the pipeline corresponding to the given name.

    Parameters
    ----------
    pipeline_name : str
        Name of the pipeline to load. Must be one of the keys of `VAR2LABEL`.

    Returns
    -------
    Pipeline
        The loaded text classification pipeline.
    """
    if pipeline_name == "pt_pipeline":
        # Plain PyTorch baseline, loaded straight from the Hub.
        model = BertForSequenceClassification.from_pretrained(HUB_MODEL_PATH, num_labels=3)
        pipeline = pt_pipeline("sentiment-analysis", tokenizer=st.session_state["tokenizer"], model=model)
    elif pipeline_name == "ort_pipeline":
        # Vanilla ONNX export of the same model, run with ONNXRuntime.
        model = ORTModelForSequenceClassification.from_pretrained(HUB_MODEL_PATH, from_transformers=True)
        if not ONNX_MODEL_PATH.exists():
            # save_pretrained expects a directory; it writes model.onnx inside it.
            model.save_pretrained(BASE_PATH)
        pipeline = ort_pipeline("text-classification", tokenizer=st.session_state["tokenizer"], model=model)
    elif pipeline_name == "ort_optimized_pipeline":
        # ONNX export with ONNXRuntime graph optimizations applied.
        if not OPTIMIZED_MODEL_PATH.exists():
            optimization_config = OptimizationConfig(optimization_level=99)
            optimizer = ORTOptimizer.from_pretrained(HUB_MODEL_PATH, feature="sequence-classification")
            optimizer.export(ONNX_MODEL_PATH, OPTIMIZED_MODEL_PATH, optimization_config=optimization_config)
            optimizer.model.config.save_pretrained(OPTIMIZED_BASE_PATH)
        model = ORTModelForSequenceClassification.from_pretrained(
            OPTIMIZED_BASE_PATH, file_name=OPTIMIZED_MODEL_PATH.name
        )
        pipeline = ort_pipeline("text-classification", tokenizer=st.session_state["tokenizer"], model=model)
    elif pipeline_name == "ort_quantized_pipeline":
        # Dynamically quantized ONNX export (int8 weights).
        if not QUANTIZED_MODEL_PATH.exists():
            quantization_config = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
            quantizer = ORTQuantizer.from_pretrained(HUB_MODEL_PATH, feature="sequence-classification")
            quantizer.export(ONNX_MODEL_PATH, QUANTIZED_MODEL_PATH, quantization_config=quantization_config)
            quantizer.model.config.save_pretrained(QUANTIZED_BASE_PATH)
        model = ORTModelForSequenceClassification.from_pretrained(
            QUANTIZED_BASE_PATH, file_name=QUANTIZED_MODEL_PATH.name
        )
        pipeline = ort_pipeline("text-classification", tokenizer=st.session_state["tokenizer"], model=model)
    else:
        raise ValueError(f"Unknown pipeline name: {pipeline_name}")
    return pipeline
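# Example usage (hypothetical output, assuming finbert-tone's Positive /
# Negative / Neutral labels):
#   pipe = load_pipeline("ort_quantized_pipeline")
#   pipe("growth is strong and we have plenty of liquidity")
#   # -> [{"label": "Positive", "score": 0.98}]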
st.set_page_config(page_title="Optimum Text Classification", page_icon="⭐")
st.title("⭐ Optimum Text Classification")
st.subheader("Classify financial news tone with 🤗 Optimum and ONNXRuntime")
st.markdown("""
[GitHub](https://github.com/ChainYo)
[Hugging Face](https://huggingface.co/ChainYo)
[LinkedIn](https://www.linkedin.com/in/thomas-chaigneau-dev/)
[Discord](https://discord.gg/)
""")
with st.expander("⭐ Details", expanded=True):
    st.markdown(
        """
        This app is a **demo** of the [🤗 Optimum Text Classification](https://huggingface.co/docs/optimum/onnxruntime/modeling_ort#optimum-inference-with-onnx-runtime) pipeline.
        We aim to compare the original PyTorch pipeline with the ONNXRuntime ones.
        We use the [Finbert-Tone](https://huggingface.co/yiyanghkust/finbert-tone) model to classify financial news tone for the demo.
        You can enter multiple sentences and classify them all at once by separating them with a semicolon (`;`).
        """
    )
| if "init_models" not in st.session_state: | |
| st.session_state["init_models"] = True | |
| if st.session_state["init_models"]: | |
| with st.spinner(text="Loading files and models..."): | |
| loading_logs = st.empty() | |
| with loading_logs.container(): | |
| BASE_PATH.mkdir(exist_ok=True) | |
| QUANTIZED_BASE_PATH.mkdir(exist_ok=True) | |
| OPTIMIZED_BASE_PATH.mkdir(exist_ok=True) | |
| if "tokenizer" not in st.session_state: | |
| tokenizer = BertTokenizer.from_pretrained(HUB_MODEL_PATH) | |
| st.session_state["tokenizer"] = tokenizer | |
| st.text("✅ Tokenizer loaded.") | |
| if "pipeline" not in st.session_state: | |
| for pipeline in VAR2LABEL.keys(): | |
| st.session_state["pipeline"] = load_pipeline(pipeline) | |
| st.text("✅ Models ready.") | |
| sleep(2) | |
| loading_logs.success("🎉 Everything is ready!") | |
| st.session_state["init_models"] = False | |
| if "inference_timers" not in st.session_state: | |
| st.session_state["inference_timers"] = {} | |
| exp_number = st.slider("The number of experiments per model.", min_value=10, max_value=300, value=150) | |
| get_only_mean = st.checkbox("Get only the mean of the inference time for each model.", value=False) | |
| input_text = st.text_area( | |
| "Enter text to classify", | |
| "there is a shortage of capital, and we need extra financing; growth is strong and we have plenty of liquidity; there are doubts about our finances; profits are flat" | |
| ) | |
run_inference = st.button("🚀 Run inference")

if run_inference:
    st.text("🔎 Running inference...")
    sentences = [sentence.strip() for sentence in input_text.split(";")]
    st.session_state["inference_timers"] = get_timers(
        samples=sentences, exp_number=exp_number, only_mean=get_only_mean
    )
    if get_only_mean:
        # A distribution plot needs several measurements per model, so the
        # mean-only results are shown as a simple table instead.
        means = pd.DataFrame({"Mean inference time (s)": st.session_state["inference_timers"]})
        st.dataframe(means)
    else:
        st.plotly_chart(get_plot(st.session_state["inference_timers"]), use_container_width=True)