Spaces:
Sleeping
Sleeping
from functools import lru_cache

import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image
from transformers import (
    AutoModelForQuestionAnswering,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
    pipeline,
)
def process_inputs(audio, option, question=None):
    """Run the processing step selected in the UI on an uploaded audio file.

    Every option first turns the audio into text with
    ``generate_text_from_audio`` and then post-processes that text.

    Args:
        audio: Path to the uploaded audio file (the UI uses
            ``gr.Audio(type="filepath")``), or ``None`` if nothing was uploaded.
        option: The dropdown choice: "Translate", "Summarize",
            "text-classification" or "Ask a Question"; ``None`` if unset.
        question: Free-text question; only used by "Ask a Question".

    Returns:
        A ``(text, plot_path)`` pair matching the interface's two outputs
        (Result textbox, Classification Plot image). The slot an option does
        not produce is ``""`` or ``None``. Invalid input yields a short message
        in the text slot instead of an exception.
    """
    # Validate the cheap inputs first. This avoids loading any model for a
    # request that cannot succeed, and Gradio always gets a 2-tuple back.
    if option not in ("Translate", "Summarize", "text-classification", "Ask a Question"):
        return "Please choose an option from the dropdown.", None
    if audio is None:
        return "Please upload an audio file.", None
    if option == "Ask a Question" and not (question and question.strip()):
        return "Please enter a question.", None

    generated_text = generate_text_from_audio(audio)
    if option == "Translate":
        return generated_text, None
    if option == "Summarize":
        return generate_summary_from_text(generated_text, minLength=20, maxLength=150), None
    if option == "text-classification":
        return "", text_classification(generated_text)
    # Only "Ask a Question" remains after the validation above.
    return ask_ques_from_text(generated_text, question), None
@lru_cache(maxsize=1)
def _asr_pipeline():
    """Build the Whisper speech-recognition pipeline once and reuse it across calls."""
    # Use half precision only when a GPU is available.
    use_cuda = torch.cuda.is_available()
    device = "cuda:0" if use_cuda else "cpu"
    torch_dtype = torch.float16 if use_cuda else torch.float32
    model_id = "openai/whisper-small"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        chunk_length_s=30,  # process long audio in 30 s chunks
        batch_size=16,  # batch size for inference - set based on your device
        device=device,
    )


def generate_text_from_audio(audio):
    """Run Whisper on an audio input with ``task="translate"`` and return the text.

    The heavy model/processor/pipeline setup happens once (cached in
    ``_asr_pipeline``), not on every request.

    Args:
        audio: Any input the ASR pipeline accepts. The Gradio UI passes a file
            path (``gr.Audio(type="filepath")``).

    Returns:
        The ``"text"`` field of the pipeline result.

    Raises:
        ValueError: If ``audio`` is ``None`` (nothing was uploaded).
    """
    if audio is None:
        raise ValueError("No audio provided.")
    # Do not also pass forced_decoder_ids=[[1, None], [2, 50359]]. In the
    # multilingual Whisper vocabulary, 50359 is <|transcribe|>, which
    # contradicts task="translate".
    result = _asr_pipeline()(audio, generate_kwargs={"task": "translate"})
    return result["text"]
@lru_cache(maxsize=1)
def _summarizer():
    """Load the summarization pipeline once and reuse it across calls."""
    return pipeline("summarization", model="Falconsai/text_summarization")


def generate_summary_from_text(text, minLength, maxLength):
    """Summarize ``text`` with the Falconsai/text_summarization model.

    Args:
        text: Text to summarize.
        minLength: Passed to the pipeline as ``min_length``.
        maxLength: Passed to the pipeline as ``max_length``.

    Returns:
        The summary string, or ``""`` when ``text`` is empty, blank or ``None``
        (no model is loaded in that case).
    """
    if not text or not text.strip():
        return ""
    # truncation=True truncates inputs that exceed the model's maximum input
    # length. do_sample=False keeps the output deterministic.
    result = _summarizer()(
        text,
        max_length=maxLength,
        min_length=minLength,
        do_sample=False,
        truncation=True,
    )
    return result[0]['summary_text']
@lru_cache(maxsize=1)
def _emotion_classifier():
    """Load the go_emotions text classifier once and reuse it across calls."""
    # top_k=None requests a score for every label rather than only the best one.
    return pipeline(task="text-classification", model="SamLowe/roberta-base-go_emotions", top_k=None)


def text_classification(text, output_path="classification_plot.png", top_n=5):
    """Score the emotions in ``text`` and save a bar chart of the highest ones.

    Args:
        text: Text to classify.
        output_path: File path the PNG chart is written to.
        top_n: How many of the highest-scoring emotions to plot.

    Returns:
        ``output_path``, the path of the saved chart image.
    """
    # The pipeline returns one list of {"label", "score"} dicts per input text.
    predictions = _emotion_classifier()([text])[0]
    top = sorted(predictions, key=lambda p: p['score'], reverse=True)[:top_n]
    labels = [p['label'] for p in top]
    scores = [p['score'] for p in top]

    # Use an explicit Figure/Axes instead of pyplot's implicit current figure.
    # Always close it, so a failed save does not leak the figure.
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        ax.barh(labels, scores, color='skyblue')
        # barh draws the first bar at the bottom; flip so the top score is on top.
        ax.invert_yaxis()
        ax.set_title(f'Top {top_n} Sentiment Scores for Emotions')
        ax.set_xlabel('Score')
        ax.set_ylabel('Emotion')
        fig.savefig(output_path)
    finally:
        plt.close(fig)
    return output_path
@lru_cache(maxsize=1)
def _qa_pipeline():
    """Load the question-answering pipeline once and reuse it across calls."""
    model_name = "deepset/roberta-base-squad2"
    # Use the first GPU when present, otherwise the CPU (-1). A hard-coded
    # device=0 fails on CPU-only hardware.
    device = 0 if torch.cuda.is_available() else -1
    return pipeline('question-answering', model=model_name, tokenizer=model_name, device=device)


def ask_ques_from_text(text, ques):
    """Answer ``ques`` using ``text`` (e.g. the audio transcript) as the context.

    Args:
        text: Context passage the answer is looked up in.
        ques: The question to answer.

    Returns:
        The ``'answer'`` field of the question-answering pipeline result.

    Raises:
        ValueError: If ``ques`` is ``None``, empty or blank.
    """
    if not ques or not ques.strip():
        raise ValueError("A non-empty question is required.")
    res = _qa_pipeline()({'question': ques, 'context': text})
    return res['answer']
# Gradio UI: the three inputs map positionally onto process_inputs(audio, option,
# question), and the two outputs onto its (text, plot_path) return value.
demo = gr.Interface(
    fn=process_inputs,
    inputs=[
        # type="filepath" hands process_inputs a path to the uploaded file.
        gr.Audio(label="Upload audio in .mp3 format", type="filepath"),
        gr.Dropdown(
            choices=["Translate", "Summarize", "text-classification", "Ask a Question"],
            label="Choose an Option",
        ),
        gr.Textbox(
            label="Enter your question if you chose Ask a question in dropdown",
            placeholder="Enter your question here",
            visible=True,
        ),
    ],
    outputs=[gr.Textbox(label="Result"), gr.Image(label="Classification Plot")],
)

# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()