# Source captured from a Hugging Face Spaces page (page banner read: "Spaces: Sleeping").
import html
import json
import time
from io import StringIO

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
# Pipeline tags recognised in ``model_info.tags`` when ``pipeline_tag`` is unset.
_KNOWN_PIPELINES = (
    "text-classification", "token-classification", "question-answering",
    "summarization", "translation", "text-generation", "fill-mask",
    "sentence-similarity", "image-classification", "object-detection",
    "image-segmentation", "text-to-image", "image-to-text",
)
# Pipelines that take a single block of text (plus optional sliders).
_TEXT_PIPELINES = (
    "text-classification", "token-classification", "fill-mask",
    "text-generation", "summarization",
)
# Pipelines that take an image (upload or URL).
_IMAGE_PIPELINES = (
    "image-classification", "object-detection", "image-segmentation", "image-to-text",
)
# Pipelines that take an audio clip (upload or URL).
_AUDIO_PIPELINES = ("audio-classification", "automatic-speech-recognition")
# Result field holding the produced string for each text-output pipeline.
_OUTPUT_TEXT_KEYS = {
    "text-generation": "generated_text",
    "summarization": "summary_text",
    "translation": "translation_text",
}
# Canned question-answering answer used by the simulator.
_SAMPLE_QA_ANSWER = "a transformer-based language model"
# Canned translations of the default sentence, keyed by target language.
_SAMPLE_TRANSLATIONS = {
    "French": "Bonjour, comment allez-vous?",
    "German": "Hallo, wie geht es dir?",
    "Spanish": "Hola, ¿cómo estás?",
    "Chinese": "你好,你好吗?",
}


def model_inference_dashboard(model_info):
    """Create a dashboard for testing model inference directly in the app.

    Renders an input form matching the model's pipeline, a "Run Inference"
    button that displays a *simulated* response (no model is actually
    called), and code snippets for calling the hosted Inference API.

    Args:
        model_info: Model metadata object. Its ``pipeline_tag``, ``tags`` and
            ``modelId``/``id`` attributes are read when present. A falsy value
            shows an error and renders nothing else.
    """
    if not model_info:
        st.error("Model information not found")
        return

    st.subheader("🧠 Model Inference Dashboard")
    pipeline_tag = _detect_pipeline_tag(model_info)
    st.info(
        "This dashboard allows you to test your model's inference capabilities. "
        f"Model pipeline: **{pipeline_tag}**"
    )

    input_data = _collect_input(pipeline_tag)

    if st.button("Run Inference", use_container_width=True):
        if input_data:
            with st.spinner("Running inference..."):
                # TODO: call the HF Inference API; for now simulate latency
                # and return a canned, pipeline-shaped response.
                time.sleep(2)
                result = _simulate_inference(pipeline_tag, input_data)
            _render_results(pipeline_tag, input_data, result)
            st.download_button(
                label="Download Results",
                data=json.dumps(result, indent=2, ensure_ascii=False),
                file_name="inference_results.json",
                mime="application/json",
            )
        else:
            st.warning("Please provide input data for inference")

    _render_api_examples(_get_model_id(model_info))


def _get_model_id(model_info):
    """Return the repo id of ``model_info`` for the code snippets.

    Accepts both the legacy ``modelId`` attribute and ``id``; falls back to a
    ``MODEL_ID`` placeholder instead of raising when neither is set.
    """
    return getattr(model_info, "modelId", None) or getattr(model_info, "id", None) or "MODEL_ID"


def _detect_pipeline_tag(model_info):
    """Return the model's pipeline tag.

    Uses ``model_info.pipeline_tag`` when set, else the first tag found in
    ``_KNOWN_PIPELINES``, else ``"text-classification"``.
    """
    pipeline_tag = getattr(model_info, "pipeline_tag", None)
    if pipeline_tag:
        return pipeline_tag
    # ``tags`` may be absent or explicitly None on sparse metadata.
    for tag in getattr(model_info, "tags", None) or []:
        if tag in _KNOWN_PIPELINES:
            return tag
    return "text-classification"  # Default fallback


def _collect_input(pipeline_tag):
    """Render the input widgets for ``pipeline_tag`` and return the payload.

    Returns None while the user has not supplied usable input (no file, an
    empty URL, or blank text), so the caller can ask for it.
    """
    if pipeline_tag in _TEXT_PIPELINES:
        return _collect_text_input(pipeline_tag)
    if pipeline_tag == "question-answering":
        return _collect_qa_input()
    if pipeline_tag == "translation":
        return _collect_translation_input()
    if pipeline_tag in _IMAGE_PIPELINES:
        return _collect_image_input()
    if pipeline_tag in _AUDIO_PIPELINES:
        return _collect_audio_input()
    # No dedicated form (e.g. sentence-similarity, text-to-image): offer plain
    # text so the simulated run is still reachable.
    return _collect_text_input(pipeline_tag)


def _collect_text_input(pipeline_tag):
    """Render a text area, plus generation/summary sliders, and return the payload."""
    st.markdown("### Text Input")
    input_text = st.text_area(
        "Enter text for inference",
        value="This model is amazing!",
        height=150,
    )
    payload = {"text": input_text}
    if pipeline_tag == "text-generation":
        col1, col2 = st.columns(2)
        with col1:
            payload["max_length"] = st.slider("Max Length", min_value=10, max_value=500, value=100)
        with col2:
            payload["temperature"] = st.slider(
                "Temperature", min_value=0.1, max_value=2.0, value=1.0, step=0.1
            )
    elif pipeline_tag == "summarization":
        payload["max_length"] = st.slider("Max Summary Length", min_value=10, max_value=200, value=50)
    return payload if input_text.strip() else None


def _collect_qa_input():
    """Render question/context fields and return the payload (None if either is blank)."""
    st.markdown("### Question & Context")
    question = st.text_input("Question", value="What is this model about?")
    context = st.text_area(
        "Context",
        value="This model is a transformer-based language model designed for natural language understanding tasks.",
        height=150,
    )
    if not (question.strip() and context.strip()):
        return None
    return {"question": question, "context": context}


def _collect_translation_input():
    """Render language pickers and source text; return the payload (None if text is blank)."""
    st.markdown("### Translation")
    source_lang = st.selectbox("Source Language", ["English", "French", "German", "Spanish", "Chinese"])
    target_lang = st.selectbox("Target Language", ["French", "English", "German", "Spanish", "Chinese"])
    translation_text = st.text_area("Text to translate", value="Hello, how are you?", height=150)
    if not translation_text.strip():
        return None
    return {
        "text": translation_text,
        "source_language": source_lang,
        "target_language": target_lang,
    }


def _collect_image_input():
    """Render an image upload/URL picker, preview the image, and return the payload."""
    st.markdown("### Image Input")
    upload_method = st.radio("Select input method", ["Upload Image", "Image URL"])
    if upload_method == "Upload Image":
        uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
        if uploaded_file is None:
            return None
        st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
        return {"image": uploaded_file}
    image_url = st.text_input(
        "Image URL",
        value="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/distilbert-base-uncased-finetuned-sst-2-english-architecture.png",
    ).strip()
    if not image_url:
        return None
    st.image(image_url, caption="Image from URL", use_column_width=True)
    return {"image_url": image_url}


def _collect_audio_input():
    """Render an audio upload/URL picker, add a player, and return the payload."""
    st.markdown("### Audio Input")
    upload_method = st.radio("Select input method", ["Upload Audio", "Audio URL"])
    if upload_method == "Upload Audio":
        uploaded_file = st.file_uploader("Upload an audio file", type=["mp3", "wav", "ogg"])
        if uploaded_file is None:
            return None
        st.audio(uploaded_file)
        return {"audio": uploaded_file}
    audio_url = st.text_input("Audio URL").strip()
    if not audio_url:
        return None
    st.audio(audio_url)
    return {"audio_url": audio_url}


def _simulate_inference(pipeline_tag, input_data):
    """Return a canned response shaped like the output for ``pipeline_tag``.

    Nothing is computed from a model; text outputs only echo parts of
    ``input_data`` where noted.
    """
    if pipeline_tag == "text-classification":
        return [
            {"label": "POSITIVE", "score": 0.9231},
            {"label": "NEGATIVE", "score": 0.0769},
        ]
    if pipeline_tag == "token-classification":
        return [
            {"entity": "B-PER", "word": "This", "score": 0.2, "index": 0, "start": 0, "end": 4},
            {"entity": "O", "word": "model", "score": 0.95, "index": 1, "start": 5, "end": 10},
            {"entity": "O", "word": "is", "score": 0.99, "index": 2, "start": 11, "end": 13},
            {"entity": "B-MISC", "word": "amazing", "score": 0.85, "index": 3, "start": 14, "end": 21},
        ]
    if pipeline_tag == "text-generation":
        return {
            "generated_text": input_data["text"] + " It provides state-of-the-art performance on a wide range of natural language processing tasks, including sentiment analysis, named entity recognition, and question answering. The model was trained on a diverse corpus of text data, allowing it to generate coherent and contextually relevant responses."
        }
    if pipeline_tag == "summarization":
        return {"summary_text": "This model provides excellent performance."}
    if pipeline_tag == "question-answering":
        context = input_data["context"]
        answer = _SAMPLE_QA_ANSWER
        start = context.find(answer)
        if start == -1:
            # The canned answer is not in the user's context: answer with a
            # span taken from the start of the context so offsets stay valid.
            start = len(context) - len(context.lstrip())
            answer = context[start:start + len(_SAMPLE_QA_ANSWER)].rstrip()
        return {"answer": answer, "start": start, "end": start + len(answer), "score": 0.953}
    if pipeline_tag == "translation":
        translation = _SAMPLE_TRANSLATIONS.get(input_data["target_language"], "Hello, how are you?")
        return {"translation_text": translation}
    if pipeline_tag == "image-classification":
        return [
            {"label": "diagram", "score": 0.9712},
            {"label": "architecture", "score": 0.0231},
            {"label": "document", "score": 0.0057},
        ]
    if pipeline_tag == "object-detection":
        return [
            {"label": "box", "score": 0.9712, "box": {"xmin": 10, "ymin": 20, "xmax": 100, "ymax": 80}},
            {"label": "text", "score": 0.8923, "box": {"xmin": 120, "ymin": 30, "xmax": 250, "ymax": 60}},
        ]
    return {"result": "Sample response for " + pipeline_tag}


def _render_results(pipeline_tag, input_data, result):
    """Render ``result`` with a visualisation suited to ``pipeline_tag``."""
    st.markdown("### Inference Results")
    if pipeline_tag == "text-classification":
        _render_text_classification(result)
    elif pipeline_tag == "token-classification":
        _render_token_classification(result)
    elif pipeline_tag in _OUTPUT_TEXT_KEYS:
        _render_text_output(pipeline_tag, input_data, result)
    elif pipeline_tag == "question-answering":
        _render_question_answering(input_data, result)
    elif pipeline_tag == "image-classification":
        _render_image_classification(result)
    else:
        # Generic display for other types
        st.json(result)


def _render_text_classification(result):
    """Show label scores as a vertical bar chart followed by the raw JSON."""
    fig = px.bar(
        pd.DataFrame(result),
        x="label",
        y="score",
        color="score",
        color_continuous_scale=px.colors.sequential.Viridis,
        title="Classification Results",
    )
    st.plotly_chart(fig, use_container_width=True)
    st.json(result)


def _render_token_classification(result):
    """Highlight entity tokens inline, show a colour legend, then the raw JSON."""
    st.markdown("#### Named Entities")
    # Entity types in first-seen order, with the B-/I- prefix stripped.
    entity_types = list(dict.fromkeys(
        item["entity"][2:] for item in result if item["entity"].startswith(("B-", "I-"))
    ))
    palette = px.colors.qualitative.Plotly
    # Cycle the palette so every type gets a colour even past the palette size.
    entity_colors = {etype: palette[i % len(palette)] for i, etype in enumerate(entity_types)}

    pieces = []
    for item in result:
        # Tokens and labels come from model output: escape before HTML rendering.
        word = html.escape(item["word"])
        entity = item["entity"]
        if entity == "O":
            pieces.append(word)
            continue
        entity_type = entity[2:] if entity.startswith(("B-", "I-")) else entity
        color = entity_colors.get(entity_type, "#CCCCCC")
        title = html.escape(f"{entity} ({item['score']:.2f})")
        pieces.append(
            f'<span style="background-color: {color}; padding: 2px; border-radius: 3px;" '
            f'title="{title}">{word}</span>'
        )
    tokens_html = " ".join(pieces)
    st.markdown(f'<div style="line-height: 2.5;">{tokens_html}</div>', unsafe_allow_html=True)

    st.markdown("#### Entity Legend")
    legend_html = "".join(
        f'<span style="background-color: {color}; padding: 2px 8px; margin-right: 10px; '
        f'border-radius: 3px;">{html.escape(etype)}</span>'
        for etype, color in entity_colors.items()
    )
    st.markdown(f"<div>{legend_html}</div>", unsafe_allow_html=True)
    st.json(result)


def _render_text_output(pipeline_tag, input_data, result):
    """Show generated/summarised/translated text and simple length statistics."""
    output_text = result[_OUTPUT_TEXT_KEYS[pipeline_tag]]
    st.markdown("#### Output Text")
    # Output may echo user input verbatim: escape before HTML rendering.
    st.markdown(
        '<div style="background-color: #f0f2f6; padding: 20px; border-radius: 10px;">'
        f"{html.escape(output_text)}</div>",
        unsafe_allow_html=True,
    )

    st.markdown("#### Text Statistics")
    input_length = len(input_data.get("text", ""))
    output_length = len(output_text)
    col1, col2, col3 = st.columns(3)
    # The unit goes in the value: a string delta would render as a spurious arrow.
    with col1:
        st.metric("Input Length", f"{input_length} characters")
    with col2:
        st.metric("Output Length", f"{output_length} characters")
    with col3:
        change_pct = (output_length - input_length) / input_length * 100 if input_length > 0 else 0
        st.metric("Length Change", f"{change_pct:.1f}%", f"{output_length - input_length} chars")


def _render_question_answering(input_data, result):
    """Show the answer, the answer highlighted in its context, and the confidence."""
    st.markdown("#### Answer")
    st.markdown(
        '<div style="background-color: #e6f3ff; padding: 10px; border-radius: 5px; font-weight: bold;">'
        f'{html.escape(result["answer"])}</div>',
        unsafe_allow_html=True,
    )

    if "context" in input_data:
        st.markdown("#### Answer in Context")
        context = input_data["context"]
        start, end = result["start"], result["end"]
        if 0 <= start <= end <= len(context):
            # Escape each slice separately so the offsets still line up.
            highlighted_context = (
                html.escape(context[:start])
                + '<span style="background-color: #ffeb3b; font-weight: bold;">'
                + html.escape(context[start:end])
                + "</span>"
                + html.escape(context[end:])
            )
        else:
            # Offsets do not fit this context: show it plain rather than mark wrong text.
            highlighted_context = html.escape(context)
        st.markdown(
            '<div style="background-color: #f0f2f6; padding: 15px; border-radius: 10px; line-height: 1.5;">'
            f"{highlighted_context}</div>",
            unsafe_allow_html=True,
        )

    st.markdown("#### Confidence")
    st.progress(result["score"])
    st.text(f"Confidence Score: {result['score']:.4f}")


def _render_image_classification(result):
    """Show label scores as a horizontal bar chart (highest on top) and the raw JSON."""
    fig = px.bar(
        pd.DataFrame(result),
        x="score",
        y="label",
        orientation="h",
        color="score",
        color_continuous_scale=px.colors.sequential.Viridis,
        title="Image Classification Results",
    )
    fig.update_layout(yaxis={"categoryorder": "total ascending"})
    st.plotly_chart(fig, use_container_width=True)
    st.json(result)


def _render_api_examples(model_id):
    """Render Python and JavaScript snippets for calling ``model_id`` via the Inference API."""
    with st.expander("API Integration"):
        st.markdown("### Use this model in your application")

        st.markdown("#### Python")
        python_code = f"""
```python
import requests

API_URL = "https://api-inference.huggingface.co/models/{model_id}"
headers = {{"Authorization": "Bearer YOUR_API_KEY"}}

def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

# Example usage
output = query({{
    "inputs": "This model is amazing!"
}})
print(output)
```
"""
        st.markdown(python_code)

        st.markdown("#### JavaScript")
        js_code = f"""
```javascript
async function query(data) {{
    const response = await fetch(
        "https://api-inference.huggingface.co/models/{model_id}",
        {{
            headers: {{ Authorization: "Bearer YOUR_API_KEY" }},
            method: "POST",
            body: JSON.stringify(data),
        }}
    );
    const result = await response.json();
    return result;
}}

// Example usage
query({{"inputs": "This model is amazing!"}}).then((response) => {{
    console.log(JSON.stringify(response));
}});
```
"""
        st.markdown(js_code)