# talk2ai/app.py
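# Voice chat demo built with Solara: record a voice message in the browser,
# transcribe it with whisper.cpp, send the text to a Llama2 endpoint through
# LangChain, and read the assistant's reply aloud with gTTS.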
#from dotenv import load_dotenv, find_dotenv
#_ = load_dotenv(find_dotenv())
import solara
from typing import List
from typing_extensions import TypedDict
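
# A chat message is a dict carrying its role ("user" or "assistant") and text.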
class MessageDict(TypedDict):
    role: str
    content: str

from tempfile import NamedTemporaryFile
import ipywebrtc
from ipywebrtc import AudioRecorder, CameraStream, AudioStream
import ipywidgets
from whispercpp import Whisper
from elevenlabs import voices, generate, save, set_api_key
from gtts import gTTS
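
# Load the whisper.cpp "tiny" model once at startup for speech-to-text.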
w = Whisper('tiny')
import os
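# Route the OpenAI-compatible client to the Shale endpoint that serves Llama2.
# Assumes SHALE_API_KEY and ELEVENLABS_API_KEY are set in the environment
# (e.g. via the commented-out dotenv lines above).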
os.environ['OPENAI_API_BASE'] = "https://shale.live/v1"
os.environ['OPENAI_API_KEY'] = os.getenv("SHALE_API_KEY")
set_api_key(os.getenv("ELEVENLABS_API_KEY"))
from langchain.chat_models import ChatOpenAI
llm = ChatOpenAI(temperature=0.7)
from IPython.display import display
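
# Module-level reactive state shared across renders: the chat history, the
# number of user messages so far, and the latest voice transcription.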
messages: solara.Reactive[List[MessageDict]] = solara.reactive([])
user_message_count = solara.reactive(0)
resultado = solara.reactive("")

@solara.component
def Page():
    with solara.Column(style={"padding": "30px"}):
        solara.Title("Talk to Llama2")
        solara.Markdown("# Talk to Llama2")
        user_message_count.value = len([m for m in messages.value if m["role"] == "user"])
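
        # `send` appends the user's text to the history; `response` asks the
        # LLM for a completion and appends the answer. Reassigning
        # `messages.value` triggers a re-render.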
        def send(message):
            messages.value = [
                *messages.value,
                {"role": "user", "content": message},
            ]

        def response(message):
            messages.value = [
                *messages.value,
                {"role": "assistant", "content": llm.predict(message)},
            ]
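
        # Run the LLM call on a background thread keyed on the user-message
        # count, so the UI stays responsive while a reply is generated.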
        def result():
            if messages.value != []:
                response(messages.value[-1]["content"])

        generated_response = solara.use_thread(result, [user_message_count.value])
        solara.ProgressLinear(generated_response.state == solara.ResultState.RUNNING)
        with solara.Column(style={"width": "70%"}):
            with solara.Sidebar():
                solara.Markdown("## Send a voice message")
                solara.Markdown("### Recorder")
                camera = CameraStream(constraints={"audio": True, "video": False})
                recorder = AudioRecorder(stream=camera)
                display(recorder)
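
                # Clicking the button transcribes the recorded audio and posts
                # it to the chat as a user message.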
                def MyButton():
                    def transcribe_voice():
                        # Dump the recorded webm bytes to a temp file, run
                        # whisper.cpp on it, and keep the first text segment.
                        with NamedTemporaryFile(suffix=".webm") as temp:
                            with open(temp.name, "wb") as f:
                                f.write(recorder.audio.value)
                            result = w.transcribe(temp.name, lang="en")
                            text = w.extract_text(result)
                        resultado.value = text[0]
                        messages.value = [
                            *messages.value,
                            {"role": "user", "content": resultado.value},
                        ]

                    solara.Button("Send voice message", on_click=transcribe_voice)

                MyButton()

                if resultado.value != "":
                    solara.Markdown("## Transcribed message:")
                    solara.Text(resultado.value, style={"color": "blue"})
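
            # Render the chat history. After the assistant's latest reply is
            # complete, it is synthesized with gTTS and auto-played below.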
            with solara.lab.ChatBox():
                for counter, item in enumerate(messages.value):
                    with solara.lab.ChatMessage(
                        user=item["role"] == "user",
                        name="Assistant" if item["role"] == "assistant" else "User",
                        avatar_background_color="#33cccc" if item["role"] == "assistant" else "#ff991f",
                        border_radius="20px",
                    ):
                        solara.Markdown(item["content"])
                    # The assistant's reply to the n-th user message sits at
                    # index 2n - 1 when roles alternate; voice it with gTTS.
                    if counter == 2 * user_message_count.value - 1:
                        tts = gTTS(messages.value[-1]["content"])
                        with NamedTemporaryFile(suffix=".mp3") as temp:
                            tts.save(temp.name)
                            audio = ipywidgets.Audio.from_file(filename=temp.name, autoplay=True, loop=False)
                        if generated_response.state != solara.ResultState.RUNNING:
                            display(audio)
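                        # Alternative TTS path using ElevenLabs (see imports
                        # above); kept disabled in favor of gTTS.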
                        # voice_name = "Adam"
                        # model_name = "eleven_monolingual_v1"
                        # audio = generate(text=f"{messages.value[-1]['content']}", voice=voice_name, model=model_name)
                        # with NamedTemporaryFile(suffix=".mp3") as temp:
                        #     save(audio, f"{temp.name}")
                        #     audio = ipywidgets.Audio.from_file(filename=f"{temp.name}", autoplay=False, loop=False)
                        # display(audio)

            solara.lab.ChatInput(send_callback=send, disabled=(generated_response.state == solara.ResultState.RUNNING))