##########################################
# Step 0: Import required libraries
##########################################
import streamlit as st  # Web interface framework
from transformers import (
    pipeline,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    AutoModelForCausalLM,
    AutoTokenizer
)  # AI model components
from datasets import load_dataset  # Voice embeddings
import torch  # Tensor computation
import soundfile as sf  # Audio file handling
import time  # Execution timing
##########################################
# Initial configuration (MUST be first)
##########################################
st.set_page_config(
    page_title="Just Comment",
    page_icon="💬",
    layout="centered",
    initial_sidebar_state="collapsed"
)
##########################################
# Optimized model loading with caching
##########################################
@st.cache_resource  # Cache so models load once, not on every Streamlit rerun
def _load_models():
    """Load all models once and cache them across reruns"""
    # Pick hardware and a matching dtype (float16 is unreliable on CPU)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    # Load emotion classifier with truncation and padding enabled
    emotion_pipe = pipeline(
        "text-classification",
        model="Thea231/jhartmann_emotion_finetuning",
        device=device,
        truncation=True,
        padding=True
    )
    # Load text generation model in half precision (float16) on GPU
    textgen_tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen1.5-0.5B",
        use_fast=True
    )
    textgen_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen1.5-0.5B",
        torch_dtype=dtype,
        device_map="auto"
    )
    # Load TTS components with hardware acceleration
    tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained(
        "microsoft/speecht5_tts",
        torch_dtype=dtype
    ).to(device)
    tts_vocoder = SpeechT5HifiGan.from_pretrained(
        "microsoft/speecht5_hifigan",
        torch_dtype=dtype
    ).to(device)
    # Preload speaker embeddings, cast to the TTS model's dtype to avoid
    # a float16/float32 mismatch inside generate_speech
    speaker_embeddings = torch.tensor(
        load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
    ).unsqueeze(0).to(device=device, dtype=dtype)
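    # The CMU Arctic x-vector dataset supplies one 512-dimensional speaker
    # embedding per utterance; index 7306 is the entry used in the official
    # SpeechT5 examples (commonly described as a US English female voice).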
    return {
        'emotion': emotion_pipe,
        'textgen_tokenizer': textgen_tokenizer,
        'textgen_model': textgen_model,
        'tts_processor': tts_processor,
        'tts_model': tts_model,
        'tts_vocoder': tts_vocoder,
        'speaker_embeddings': speaker_embeddings,
        'device': device
    }
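
# Note: with @st.cache_resource, the dictionary above is built once per
# process and reused across Streamlit reruns; without it, every widget
# interaction would reload all of the model weights from scratch.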

##########################################
# UI Components
##########################################
def _display_interface():
    """Render the user interface and return the typed comment"""
    st.title("Just Comment")
    st.markdown("### I'm listening to you, my friend~")
    return st.text_area(
        "📝 Enter your comment:",
        placeholder="Type your message here...",
        height=150,
        key="user_input"
    )

##########################################
# Core Processing Functions
##########################################
def _analyze_emotion(text, classifier):
    """Classify the dominant emotion, truncating long inputs for speed"""
    start_time = time.time()
    results = classifier(text[:512], return_all_scores=True)[0]  # Limit input length
    valid_emotions = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'}
    # Find the dominant emotion among the labels we have templates for
    dominant = max(
        (e for e in results if e['label'].lower() in valid_emotions),
        key=lambda x: x['score'],
        default={'label': 'neutral', 'score': 1.0}
    )
    st.write(f"⏱️ Emotion analysis time: {time.time()-start_time:.2f}s")
    return dominant
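
# Illustrative output (scores invented for the example): after indexing [0],
# `results` is a flat list like
#   [{'label': 'anger', 'score': 0.91}, {'label': 'joy', 'score': 0.04}, ...]
# and max() picks the highest-scoring label among the six supported ones,
# defaulting to 'neutral' if none of them appears.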

def _generate_prompt(text, emotion):
    """Build an emotion-specific prompt (plain templates filled via .format)"""
    prompt_templates = {
        "sadness": "Sadness detected: {input}\nRespond with: 1. Empathy 2. Support 3. Solution\nResponse:",
        "joy": "Joy detected: {input}\nRespond with: 1. Thanks 2. Appreciation 3. Engagement\nResponse:",
        "love": "Love detected: {input}\nRespond with: 1. Warmth 2. Community 3. Exclusive Offer\nResponse:",
        "anger": "Anger detected: {input}\nRespond with: 1. Apology 2. Action 3. Compensation\nResponse:",
        "fear": "Fear detected: {input}\nRespond with: 1. Reassurance 2. Safety 3. Support\nResponse:",
        "surprise": "Surprise detected: {input}\nRespond with: 1. Acknowledgement 2. Solution 3. Follow-up\nResponse:",
        "neutral": "Feedback: {input}\nRespond professionally:\n1. Acknowledgement\n2. Assistance\n3. Next Steps\nResponse:"
    }
    # Fall back to the neutral template for any unexpected label
    template = prompt_templates.get(emotion.lower(), prompt_templates["neutral"])
    return template.format(input=text[:300])  # Limit input length
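
# Example (hypothetical input): _generate_prompt("My order arrived broken", "anger")
# returns the filled template:
#   Anger detected: My order arrived broken
#   Respond with: 1. Apology 2. Action 3. Compensation
#   Response: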

def _process_response(raw_text):
    """Trim the generated text to a clean, bounded reply"""
    # Keep only the text after the last "Response:" marker (drops the prompt)
    response = raw_text.split("Response:")[-1].strip()
    # Drop any trailing sentence fragment after the last period
    if '.' in response:
        response = response.rsplit('.', 1)[0] + '.'
    # Cap long replies at 200 chars; replace implausibly short ones entirely
    return response[:200] if len(response) > 50 else "Thank you for your feedback. We'll respond shortly."
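
# Example: for raw text ending "...Response: We sincerely apologize for the
# inconvenience. A full refund is on the way. Plea", the split keeps the text
# after the final "Response:" and rsplit drops the dangling fragment, giving
# "We sincerely apologize for the inconvenience. A full refund is on the way."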

def _generate_text(user_input, models):
    """Run emotion analysis, prompt construction, and text generation"""
    start_time = time.time()
    # Emotion analysis
    emotion = _analyze_emotion(user_input, models['emotion'])
    # Generate prompt
    prompt = _generate_prompt(user_input, emotion['label'])
    # Tokenize and generate
    inputs = models['textgen_tokenizer'](
        prompt,
        return_tensors="pt",
        max_length=128,
        truncation=True
    ).to(models['device'])
    outputs = models['textgen_model'].generate(
        **inputs,  # Passes input_ids and attention_mask together
        max_new_tokens=80,  # Strict limit for speed
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=models['textgen_tokenizer'].eos_token_id
    )
    # Decode and process
    generated = models['textgen_tokenizer'].decode(
        outputs[0],
        skip_special_tokens=True
    )
    st.write(f"⏱️ Text generation time: {time.time()-start_time:.2f}s")
    return _process_response(generated)
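
# Note: decode() returns the prompt plus the continuation as one string;
# _process_response() keeps only the text after the final "Response:" marker,
# so just the newly generated reply reaches the UI.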

def _generate_speech(text, models):
    """Synthesize speech from text with SpeechT5 and HiFi-GAN"""
    start_time = time.time()
    # Process text
    inputs = models['tts_processor'](
        text=text[:150],  # Limit text length
        return_tensors="pt"
    ).to(models['device'])
    # Generate audio (no gradients needed at inference time)
    with torch.inference_mode():
        spectrogram = models['tts_model'].generate_speech(
            inputs["input_ids"],
            models['speaker_embeddings']
        )
        waveform = models['tts_vocoder'](spectrogram)
    # Save audio; cast to float32 since soundfile cannot write float16 arrays
    sf.write("response.wav", waveform.to(torch.float32).cpu().numpy(), 16000)
    st.write(f"⏱️ Speech synthesis time: {time.time()-start_time:.2f}s")
    return "response.wav"

##########################################
# Main Application Flow
##########################################
def main():
    """Wire the cached models to the Streamlit interface"""
    # Load models first (cached after the first run)
    ml_models = _load_models()
    # Display interface
    user_input = _display_interface()
    if user_input:
        total_start = time.time()
        # Text generation
        with st.spinner("🚀 Analyzing & generating response..."):
            text_response = _generate_text(user_input, ml_models)
        # Display results
        st.subheader("📄 Generated Response")
        st.markdown(f"```\n{text_response}\n```")
        # Audio generation
        with st.spinner("🔊 Converting to speech..."):
            audio_file = _generate_speech(text_response, ml_models)
        st.audio(audio_file, format="audio/wav")
        st.write(f"⏱️ Total execution time: {time.time()-total_start:.2f}s")

if __name__ == "__main__":
    main()
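
# To launch locally (assuming this file is saved as app.py):
#   streamlit run app.py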