Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import random | |
| import string | |
| import numpy as np | |
| import gradio as gr | |
| import requests | |
| import soundfile as sf | |
| from transformers import pipeline, set_seed | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import logging | |
| import sys | |
| import gradio as gr | |
| from transformers import pipeline, AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM | |
# --- Configuration, read once from the environment at import time -----------
# DEBUG is truthy when the env var starts with "t", "y" or "1"
# ("true"/"True", "yes", "1", ...). startswith on a tuple avoids the
# original IndexError when DEBUG is set to the empty string, and the
# .lower() makes the check case-insensitive ("True" used to read as False).
DEBUG = os.environ.get("DEBUG", "false").lower().startswith(("t", "y", "1"))
# Hard cap (in words) for the prompt assembled in chat_with_gpt.
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
# UI language preselected in the radio button ("English" or "Spanish").
DEFAULT_LANG = os.environ.get("DEFAULT_LANG", "English")
# Hugging Face token for gated model/Interface loads; None means anonymous.
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
| HEADER = """ | |
| # Poor Man's Duplex | |
| Talk to a language model like you talk on a Walkie-Talkie! Well, with larger latencies. | |
| The models are [EleutherAI's GPT-J-6B](https://huggingface.co/EleutherAI/gpt-j-6B) for English, and [BERTIN GPT-J-6B](https://huggingface.co/bertin-project/bertin-gpt-j-6B) for Spanish. | |
| """.strip() | |
| FOOTER = """ | |
| <div align=center> | |
| <img src="https://visitor-badge.glitch.me/badge?page_id=versae/poor-mans-duplex"/> | |
| <div align=center> | |
| """.strip() | |
| asr_model_name_es = "jonatasgrosman/wav2vec2-large-xlsr-53-spanish" | |
| model_instance_es = AutoModelForCTC.from_pretrained(asr_model_name_es, use_auth_token=HF_AUTH_TOKEN) | |
| processor_es = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model_name_es, use_auth_token=HF_AUTH_TOKEN) | |
| asr_es = pipeline( | |
| "automatic-speech-recognition", | |
| model=model_instance_es, | |
| tokenizer=processor_es.tokenizer, | |
| feature_extractor=processor_es.feature_extractor, | |
| decoder=processor_es.decoder | |
| ) | |
| tts_model_name = "facebook/tts_transformer-es-css10" | |
| speak_es = gr.Interface.load(f"huggingface/{tts_model_name}", api_key=HF_AUTH_TOKEN) | |
| transcribe_es = lambda input_file: asr_es(input_file, chunk_length_s=5, stride_length_s=1)["text"] | |
def generate_es(text, **kwargs):
    """Generate a Spanish continuation of `text` via the hosted BERTIN GPT-J-6B Space.

    Only ``kwargs["max_length"]`` is honoured (defaulting to 50); top_k,
    top_p, temperature, do_sample and do_clean are pinned to the values the
    Space UI uses (100, 50, 0.95, True, True).

    Returns the generated text, or "" when the request fails for any reason.
    """
    api_uri = "https://hf.space/embed/bertin-project/bertin-gpt-j-6B/+/api/predict/"
    payload = {"data": [text, kwargs.get("max_length", 50), 100, 50, 0.95, True, True]}
    try:
        # `json=` also sets the Content-Type: application/json header, which
        # the previous `data=json.dumps(...)` form did not; the timeout keeps
        # a stuck Space from hanging the whole chat turn.
        response = requests.post(api_uri, json=payload, timeout=60)
    except requests.RequestException:
        return ""
    if not response.ok:
        return ""
    result = response.json()
    if DEBUG:
        print("Spanish response >", result)
    return result["data"][0]
| asr_model_name_en = "jonatasgrosman/wav2vec2-large-xlsr-53-english" | |
| model_instance_en = AutoModelForCTC.from_pretrained(asr_model_name_en) | |
| processor_en = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model_name_en) | |
| asr_en = pipeline( | |
| "automatic-speech-recognition", | |
| model=model_instance_en, | |
| tokenizer=processor_en.tokenizer, | |
| feature_extractor=processor_en.feature_extractor, | |
| decoder=processor_en.decoder | |
| ) | |
| tts_model_name = "facebook/fastspeech2-en-ljspeech" | |
| speak_en = gr.Interface.load(f"huggingface/{tts_model_name}", api_key=HF_AUTH_TOKEN) | |
| transcribe_en = lambda input_file: asr_en(input_file, chunk_length_s=5, stride_length_s=1)["text"] | |
# Kept for reference: direct Interface.load of GPT-J was replaced by the
# EleutherAI completion API used in generate_en.
# generate_iface = gr.Interface.load("huggingface/EleutherAI/gpt-j-6B", api_key=HF_AUTH_TOKEN)

# Silent placeholder clip (zero samples, 16 kHz FLAC) written once at
# startup; played in the agent's audio widget when there is nothing to say.
empty_audio = 'empty.flac'
sf.write(empty_audio, [], 16000)

# Remote model applied to ASR transcripts before prompting — presumably it
# restores casing/punctuation (name suggests "de-uncaser"); verify upstream.
deuncase = gr.Interface.load("huggingface/pere/DeUnCaser", api_key=HF_AUTH_TOKEN)
def generate_en(text, **kwargs):
    """Generate an English continuation of `text` via EleutherAI's completion API.

    Only ``kwargs["max_length"]`` is honoured (defaulting to 50, mapped to the
    API's ``response_length``); top_p and temp are fixed at 0.9 / 0.8 and the
    prompt is stripped from the completion (``remove_input``).

    Returns the generated text with leading whitespace removed, or "" when the
    request fails for any reason.
    """
    api_uri = "https://api.eleuther.ai/completion"
    payload = {
        "context": text,
        "top_p": 0.9,
        "temp": 0.8,
        "response_length": kwargs.get("max_length", 50),
        "remove_input": True,
    }
    try:
        # `json=` sets the Content-Type: application/json header, which the
        # previous `data=json.dumps(...)` form did not; the timeout keeps a
        # stuck endpoint from hanging the whole chat turn.
        response = requests.post(api_uri, json=payload, timeout=60)
    except requests.RequestException:
        return ""
    if not response.ok:
        return ""
    result = response.json()
    if DEBUG:
        print("English response >", result)
    return result[0]["generated_text"].lstrip()
def select_lang(lang):
    """Return the (generate, transcribe, speak) callables for `lang`.

    "spanish" (case-insensitive) selects the Spanish stack; anything else
    falls back to English.
    """
    if lang.lower() == "spanish":
        return generate_es, transcribe_es, speak_es
    # English is the default for any unrecognised language value.
    return generate_en, transcribe_en, speak_en
def select_lang_vars(lang):
    """Return (AGENT, USER, CONTEXT) interview boilerplate for `lang`.

    CONTEXT keeps literal ``{AGENT}``/``{USER}`` placeholders on purpose:
    chat_with_gpt fills them in later with str.format. Any language other
    than Spanish maps to the English variant.
    """
    spanish = lang.lower() == "spanish"
    agent_name = "BERTIN" if spanish else "ELEUTHER"
    user_name = "ENTREVISTADOR" if spanish else "INTERVIEWER"
    if spanish:
        context_template = (
            "La siguiente conversación es un extracto de una entrevista a {AGENT} "
            "celebrada en Madrid para Radio Televisión Española:\n"
            "{USER}: Bienvenido, {AGENT}. Un placer tenerlo hoy con nosotros.\n"
            "{AGENT}: Gracias. El placer es mío."
        )
    else:
        context_template = (
            "The next conversation is an excerpt from an interview to {AGENT} "
            "that appeared in the New York Times:\n"
            "{USER}: Welcome, {AGENT}. It is a pleasure to have you here today.\n"
            "{AGENT}: Thanks. The pleasure is mine."
        )
    return agent_name, user_name, context_template
def format_chat(history):
    """Render the (user, bot) turn history as a collapsible HTML conversation log.

    Each turn becomes a pair of chat bubbles (green for the user, gray for the
    bot) inside a scrollable <details> element. Returns the HTML string.
    """
    interventions = []
    for user, bot in history:
        interventions.append(f"""
<div data-testid="user" style="background-color:#16a34a" class="px-3 py-2 rounded-[22px] rounded-bl-none place-self-start text-white ml-7 text-sm">{user}</div>
<div data-testid="bot" style="background-color:gray" class="px-3 py-2 rounded-[22px] rounded-br-none text-white ml-7 text-sm">{bot}</div>
""")
    # BUG FIX: the log must end with </details>, not </summary> — the
    # unclosed <details> swallowed every element rendered after the log.
    return f"""<details><summary>Conversation log</summary>
<div class="overflow-y-auto h-[40vh]">
<div class="flex flex-col items-end space-y-4 p-3">
{"".join(interventions)}
</div>
</div>
</details>"""
def chat_with_gpt(lang, agent, user, context, audio_in, history):
    """Run one walkie-talkie turn: transcribe `audio_in`, query the LM, speak the reply.

    Parameters mirror the Gradio inputs (language radio, agent/user name
    textboxes, context template, microphone filepath, chat state). Returns a
    4-tuple (chatbot_state, history_state, reply_audio, log_html) matching the
    outputs wired in the UI. With no audio, the current state is returned
    unchanged alongside the silent placeholder clip.
    """
    # Guard before the early return too: history may arrive as None and
    # format_chat iterates it.
    history = history or []
    if not audio_in:
        return history, history, empty_audio, format_chat(history)

    generate, transcribe, speak = select_lang(lang)
    AGENT, USER, _ = select_lang_vars(lang)
    # ASR output is raw; DeUnCaser post-processes the transcript.
    user_message = deuncase(transcribe(audio_in))
    generation_kwargs = {
        "max_length": 50,
    }
    # Capitalize only the first word of the transcript.
    message = user_message.split(" ", 1)[0].capitalize() + " " + user_message.split(" ", 1)[-1]

    context = context.format(USER=user or USER, AGENT=agent or AGENT).strip()
    if context and context[-1] not in ".:":  # empty context would IndexError on [-1]
        context += "."
    context_length = len(context.split())

    def _render_history(take):
        # One "USER: msg.\nAGENT: resp." pair per kept turn, oldest first.
        return "\n".join(
            f"{user}: {past_msg.capitalize()}.\n{agent}: {past_resp}."
            for past_msg, past_resp in history[-len(history) + take:]
        )

    # Drop the oldest turns until the prompt fits the MAX_LENGTH word budget.
    history_take = 0
    history_context = _render_history(history_take)
    while len(history_context.split()) > MAX_LENGTH - (generation_kwargs["max_length"] + context_length):
        history_take += 1
        history_context = _render_history(history_take)
        if history_take >= MAX_LENGTH:
            break
    if history_context:
        # BUG FIX: history used to be glued straight onto the context with no
        # separator, corrupting the prompt's line structure.
        context += "\n" + history_context

    response = ""
    for _ in range(5):  # retry until a usable (non-empty, non-parroting) reply
        prompt = f"{context}\n\n{user}: {message}.\n"
        response = generate(prompt, context_length=context_length, **generation_kwargs)
        if DEBUG:
            print("\n-----\n" + response + "\n-----\n")
        # Take the agent's first intervention after the prompt.
        candidates = [
            r for r in response.replace(prompt, "").split(f"{AGENT}:") if r.strip()
        ]
        if not candidates:
            # Nothing usable came back (used to IndexError on [0]); retry.
            response = ""
            continue
        response = candidates[0].split(USER)[0].replace(f"{AGENT}:", "\n").strip()
        if response and response[0] in string.punctuation:
            response = response[1:].strip()
        if response.strip().startswith(f"{user}: {message}"):
            response = response.strip().split(f"{user}: {message}")[-1]
        if response.replace(".", "").strip() and message.replace(".", "").strip() != response.replace(".", "").strip():
            break
    if DEBUG:
        print()
        print("CONTEXT:")
        print(context)
        print()
        print("MESSAGE")
        print(message)
        print()
        print("RESPONSE:")
        print(response)
    if not response.strip():
        # BUG FIX: the original compared lang.lower() == "Spanish", which is
        # never true, so Spanish users always got the English apology.
        response = "Lo siento, no puedo hablar ahora" if lang.lower() == "spanish" else "Sorry, can't talk right now"
    history.append((user_message, response))
    return history, history, speak(response), format_chat(history)
# --- Gradio UI wiring -------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    # Language selector; the initial value comes from the DEFAULT_LANG env var.
    lang = gr.Radio(label="Language", choices=["English", "Spanish"], value=DEFAULT_LANG, type="value")
    AGENT, USER, CONTEXT = select_lang_vars(DEFAULT_LANG)
    # Editable interview context; keeps literal {AGENT}/{USER} placeholders
    # that chat_with_gpt fills in via str.format.
    context = gr.Textbox(label="Context", lines=5, value=CONTEXT)
    with gr.Row():
        # Microphone input (as a filepath for the ASR pipeline) and the
        # agent's spoken reply, seeded with the silent placeholder clip.
        audio_in = gr.Audio(label="User", source="microphone", type="filepath")
        audio_out = gr.Audio(label="Agent", interactive=False, value=empty_audio)
    # chat_btn = gr.Button("Submit")
    with gr.Row():
        user = gr.Textbox(label="User", value=USER)
        agent = gr.Textbox(label="Agent", value=AGENT)
    # Switching language replaces the agent/user names and the context template
    # (select_lang_vars returns them in that order).
    lang.change(select_lang_vars, inputs=[lang], outputs=[agent, user, context])
    # Server-side chat state; chatbot is an unused placeholder for the first
    # output slot of chat_with_gpt.
    history = gr.Variable(value=[])
    chatbot = gr.Variable()  # gr.Chatbot(color_map=("green", "gray"), visible=False)
    # chat_btn.click(chat_with_gpt, inputs=[lang, agent, user, context, audio_in, history], outputs=[chatbot, history, audio_out])
    log = gr.HTML()
    # A new recording triggers a full chat turn: refreshed state, the agent's
    # audio reply, and the HTML conversation log.
    audio_in.change(chat_with_gpt, inputs=[lang, agent, user, context, audio_in, history], outputs=[chatbot, history, audio_out, log])
    gr.Markdown(FOOTER)
demo.launch()