import gradio as gr
import torchaudio
import pandas as pd
import torch.nn.functional as F
import whisper
import logging
import plotly.express as px
from utils.config_loader import ConfigLoader
from data_loading.feature_extractor import (
    PretrainedAudioEmbeddingExtractor,
    PretrainedTextEmbeddingExtractor
)
import chardet
import torch
from models.models import BiFormer

# DEVICE = torch.device('cuda')
DEVICE = torch.device('cpu')

# Configure logging
logging.basicConfig(level=logging.INFO)

# Constants with emojis and colors
LABEL_TO_EMOTION = {
    0: '😠 Anger',
    1: '🤢 Disgust',
    2: '😨 Fear',
    3: '😊 Joy/Happiness',
    4: '😐 Neutral',
    5: '😢 Sadness',
    6: '😲 Surprise/Enthusiasm'
}
EMOTION_COLORS = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEEAD', '#FF9999', '#D4A5A5']
emotion_color_map = {emotion: color for emotion, color in zip(LABEL_TO_EMOTION.values(), EMOTION_COLORS)}
TARGET_SAMPLE_RATE = 16000

def initialize_components(config_path='config.toml'):
    """Initialize configuration and models."""
    config = ConfigLoader(config_path)
    config.show_config()

    model = BiFormer(
        audio_dim=256,
        text_dim=1024,
        seg_len=95,
        hidden_dim=256,
        hidden_dim_gated=256,
        num_transformer_heads=8,
        num_graph_heads=2,
        positional_encoding=False,
        dropout=0.15,
        mode='mean',
        tr_layer_number=5,
        out_features=256,
        num_classes=7
    )

    checkpoint_path = "best_model_dev_0_5895_epoch_8.pt"
    state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state)
    model = model.to(DEVICE)
    model.eval()

    return (
        PretrainedAudioEmbeddingExtractor(config),
        PretrainedTextEmbeddingExtractor(config),
        whisper.load_model("base"),
        model
    )


audio_extractor, text_extractor, whisper_model, bimodal_model = initialize_components()

def load_and_preprocess_audio(audio_path):
    """Load and preprocess audio to mono 16kHz format."""
    try:
        waveform, orig_sr = torchaudio.load(audio_path)
        waveform = waveform.mean(dim=0, keepdim=False)
        if orig_sr != TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(
                orig_freq=orig_sr,
                new_freq=TARGET_SAMPLE_RATE
            )
            waveform = resampler(waveform)
        return waveform, TARGET_SAMPLE_RATE
    except Exception as e:
        logging.error(f"Audio loading failed: {e}")
        raise
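
# Example usage (hypothetical path): the returned waveform is a 1-D float tensor
# resampled to 16 kHz regardless of the source channel count or sample rate.
#   waveform, sr = load_and_preprocess_audio("sample.wav")
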
def transcribe_audio(audio_path):
    """Convert speech to text using Whisper."""
    try:
        result = whisper_model.transcribe(audio_path, fp16=False)
        return result.get('text', '')
    except Exception as e:
        logging.error(f"Transcription failed: {e}")
        return ""

def get_predictions(input_data, extractor, is_audio=False):
    """Generic prediction function for audio/text."""
    try:
        if is_audio:
            pred, emb = extractor.extract(input_data, TARGET_SAMPLE_RATE)
        else:
            pred, emb = extractor.extract(input_data)
        return F.softmax(pred, dim=-1)[0].tolist(), emb
    except Exception as e:
        logging.error(f"Prediction failed: {e}")
        return [0.0] * len(LABEL_TO_EMOTION), None

def create_emotion_df(probabilities):
    """Create an emotion probability dataframe with values as percentages."""
    df = pd.DataFrame({
        'Emotion': list(LABEL_TO_EMOTION.values()),
        'Probability': [round(p * 100, 2) for p in probabilities]
    })
    return df

def create_plot(df, title):
    """Create Plotly bar chart with proper formatting."""
    fig = px.bar(
        df,
        x='Emotion',
        y='Probability',
        title=title,
        color='Emotion',
        color_discrete_map=emotion_color_map
    )
    fig.update_layout(
        xaxis=dict(tickangle=-45, tickfont=dict(size=12)),
        yaxis=dict(title='Probability (%)'),
        margin=dict(l=20, r=20, t=60, b=100),
        height=400,
        showlegend=False
    )
    return fig

def get_top_emotion(probabilities):
    """Return formatted top emotion with percentage."""
    max_idx = probabilities.index(max(probabilities))
    return f"{LABEL_TO_EMOTION[max_idx]} ({max(probabilities) * 100:.1f}%)"
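
# Illustrative call (probabilities are made up); the highest value wins:
#   get_top_emotion([0.1, 0.05, 0.05, 0.6, 0.1, 0.05, 0.05]) -> '😊 Joy/Happiness (60.0%)'
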
def process_audio(audio_path):
    """Main processing pipeline."""
    try:
        if not audio_path:
            empty = create_emotion_df([0] * len(LABEL_TO_EMOTION))
            return (
                create_plot(empty, "🎧 Audio Analysis"),
                "No audio detected",
                create_plot(empty, "📝 Text Analysis"),
                create_plot(empty, "🤝 Audio-Text Analysis"),
                "🎙 Please provide audio input"
            )

        # Audio processing
        waveform, sr = load_and_preprocess_audio(audio_path)
        audio_probs, audio_features = get_predictions(waveform, audio_extractor, is_audio=True)
        audio_df = create_emotion_df(audio_probs)

        # Text processing
        text = transcribe_audio(audio_path)
        if text.strip():
            text_probs, text_features = get_predictions(text, text_extractor)
        else:
            text_probs, text_features = [0.0] * len(LABEL_TO_EMOTION), None
        text_df = create_emotion_df(text_probs)

        # Combined results
        combined_probs = bimodal_model(audio_features, text_features)
        combined_probs = F.softmax(combined_probs, dim=-1)[0].detach().cpu().numpy().tolist()
        combined_df = create_emotion_df(combined_probs)
        top_emotion = get_top_emotion(combined_probs)

        return (
            create_plot(audio_df, "🎧 Audio Analysis"),
            f"🗣️ Transcription:\n{text}",
            create_plot(text_df, "📝 Text Analysis"),
            create_plot(combined_df, "🤝 Audio-Text Analysis"),
            f"## 🏆 Dominant Emotion: {top_emotion}"
        )
    except Exception as e:
        logging.error(f"Processing failed: {e}")
        error_df = create_emotion_df([0] * len(LABEL_TO_EMOTION))
        return (
            create_plot(error_df, "🎧 Audio Analysis"),
            "❌ Error processing audio",
            create_plot(error_df, "📝 Text Analysis"),
            create_plot(error_df, "🤝 Audio-Text Analysis"),
            "⚠️ Processing Error"
        )
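
# Quick standalone check (hypothetical file path; in the app this function is wired
# to the audio_input.change event below):
#   audio_fig, transcript, text_fig, combined_fig, top = process_audio("sample.wav")
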
def create_app():
    """Build enhanced Gradio interface."""
    with gr.Blocks(theme=gr.themes.Soft(), title="Emotion Detection from Speech") as demo:
        gr.Markdown("# Intelligent system for Bilingual Bimodal Emotion Recognition (BiBiER)")
        gr.Markdown("Analyze emotions in Russian and English speech through both audio characteristics and spoken content")

        with gr.Row():
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Record or Upload Audio",
                format="wav",
                interactive=True
            )

        with gr.Row():
            top_emotion = gr.Markdown("## 🏆 Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")

        with gr.Row():
            with gr.Column():
                audio_plot = gr.Plot(label="Audio Analysis")
            with gr.Column():
                text_plot = gr.Plot(label="Text Analysis")
            with gr.Column():
                combined_plot = gr.Plot(label="Audio-Text Analysis")

        transcription = gr.Textbox(
            label="📝 Transcription Results",
            placeholder="Transcribed text will appear here...",
            lines=3,
            max_lines=6
        )

        audio_input.change(
            process_audio,
            inputs=audio_input,
            outputs=[audio_plot, transcription, text_plot, combined_plot, top_emotion]
        )

    return demo

def create_authors():
    df = pd.DataFrame({
        "Name": ["Author", "Author"]
    })
    with gr.Blocks() as demo:
        gr.Dataframe(df)
    return demo

def create_reqs():
    """Create requirements tab with formatted data and explanations."""
    # 1️⃣ Detect file encoding
    with open('requirements.txt', 'rb') as f:
        raw_data = f.read()
    encoding = chardet.detect(raw_data)['encoding']

    # 2️⃣ Parse requirements into library-version pairs
    def parse_requirements(lines):
        requirements = []
        for line in lines:
            line = line.strip()
            if not line or line.startswith('#'):
                continue  # Skip empty lines and comments
            parts = line.split('==')
            library = parts[0].strip()
            version = parts[1].strip() if len(parts) > 1 else 'latest'
            requirements.append((library, version))
        return requirements
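
    # Illustrative input/output (made-up package pins):
    #   parse_requirements(["torch==2.1.0", "# comment", "gradio"])
    #   -> [("torch", "2.1.0"), ("gradio", "latest")]
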
    # 3️⃣ Load and process requirements
    with open('requirements.txt', 'r', encoding=encoding) as f:
        requirements = parse_requirements(f.readlines())

    # 4️⃣ Create structured data for display
    df = pd.DataFrame({
        "📦 Library": [lib for lib, _ in requirements],
        "📌 Recommended Version": [ver for _, ver in requirements]
    })

    # 5️⃣ Build interactive components
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📦 Dependency Requirements")
        gr.Markdown("""
        ## Essential Packages for Operation
        These are the core libraries and versions needed to run the application successfully:
        """)
        gr.Dataframe(
            df,
            interactive=True,
            wrap=True,
            elem_id="requirements-table"
        )
        gr.Markdown("_Note: Versions marked 'latest' can use any compatible version_")

    return demo

def create_demo():
    app = create_app()
    authors = create_authors()
    reqs = create_reqs()
    demo = gr.TabbedInterface(
        [app, authors, reqs],
        tab_names=["✅ App", "📖 Authors", "📋 Requirements"]
    )
    return demo

if __name__ == "__main__":
    demo = create_demo()
    demo.launch()
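
# For containerized or LAN deployments, the launch call can take standard Gradio
# parameters (shown here as an assumption about the target environment):
#   demo.launch(server_name="0.0.0.0", server_port=7860)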