Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,13 +10,17 @@ import openai
|
|
| 10 |
from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, Prompt
|
| 11 |
from transformers import pipeline
|
| 12 |
import opencc
|
|
|
|
|
|
|
| 13 |
|
| 14 |
converter = opencc.OpenCC('t2s') # 创建一个OpenCC实例,指定繁体字转为简体字
|
| 15 |
ocr = CnOcr() # 初始化ocr模型
|
| 16 |
history_max_len = 500 # 机器人记忆的最大长度
|
| 17 |
all_max_len = 2000 # 输入的最大长度
|
| 18 |
asr_model_id = "openai/whisper-tiny" # 更新为你的模型ID
|
| 19 |
-
|
|
|
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
def get_text_emb(open_ai_key, text): # 文本向量化
|
|
@@ -140,11 +144,23 @@ def get_response_by_llama_index(open_ai_key, msg, bot, query_engine): # 获取
|
|
| 140 |
return bot[max(0, len(bot) - 3):] # 返回最近3轮的历史记录
|
| 141 |
|
| 142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings, query_engine, index_type): # 获取机器人回复
|
| 144 |
if index_type == 1: # 如果是使用自己的索引
|
| 145 |
-
|
| 146 |
else: # 如果是使用llama_index索引
|
| 147 |
-
|
|
|
|
|
|
|
| 148 |
|
| 149 |
|
| 150 |
def up_file(files): # 上传文件
|
|
@@ -235,6 +251,7 @@ with gr.Blocks() as demo:
|
|
| 235 |
with gr.Column():
|
| 236 |
md = gr.Markdown("""操作说明 step 1:点击左侧区域,上传PDF,进行解析""") # 操作说明
|
| 237 |
chat_bot = gr.Chatbot(visible=False) # 聊天机器人
|
|
|
|
| 238 |
with gr.Row():
|
| 239 |
asr_type = gr.Radio(value='self', choices=['self', 'openai'], label='语音识别方式', visible=False) # 语音识别方式
|
| 240 |
audio_inputs = gr.Audio(source="microphone", type="filepath", label="点击录音输入", visible=False) # 录音输入
|
|
@@ -250,7 +267,7 @@ with gr.Blocks() as demo:
|
|
| 250 |
audio_inputs.change(transcribe_speech, [open_ai_key, audio_inputs, asr_type], [msg_txt]) # 录音输入
|
| 251 |
chat_bu.click(get_response,
|
| 252 |
[open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state, query_engine, index_type],
|
| 253 |
-
[chat_bot]) # 发送消息
|
| 254 |
|
| 255 |
if __name__ == "__main__":
|
| 256 |
-
demo.queue().launch()
|
|
|
|
from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, Prompt
from transformers import pipeline
import opencc
import scipy.io.wavfile  # explicit submodule import: bare `import scipy` does not reliably expose scipy.io.wavfile
import torch

# Converter instance that maps Traditional Chinese text to Simplified Chinese.
converter = opencc.OpenCC('t2s')
# OCR model (CnOcr is imported earlier in the file, outside this hunk).
ocr = CnOcr()
history_max_len = 500  # maximum length of chat history the bot remembers
all_max_len = 2000  # maximum total input length
asr_model_id = "openai/whisper-tiny"  # speech-recognition model ID

# Run both pipelines on GPU when one is available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
asr_pipe = pipeline("automatic-speech-recognition", model=asr_model_id, device=device)
synthesiser = pipeline("text-to-speech", "suno/bark-small", device=device)
| 26 |
def get_text_emb(open_ai_key, text): # 文本向量化
|
|
|
|
| 144 |
return bot[max(0, len(bot) - 3):] # 返回最近3轮的历史记录
|
| 145 |
|
| 146 |
|
| 147 |
import hashlib

import scipy.io.wavfile  # ensure the wavfile submodule is loaded; bare `import scipy` may not expose scipy.io.wavfile


def get_audio_answer(answer):  # 获取语音回答
    """Synthesize *answer* to speech and save it as a wav file.

    The file is written into the current working directory and named after
    the MD5 hex digest of the answer text, so the same answer text always
    maps to the same path.

    :param answer: bot reply text to vocalize (assumed to be a non-empty str)
    :return: path of the generated wav file, i.e. "<md5>.wav"
    """
    # NOTE(review): do_sample=True makes synthesis non-deterministic, so a
    # repeated answer overwrites its earlier file with a fresh rendering.
    speech = synthesiser(answer, forward_params={"do_sample": True})  # generate speech
    md5 = hashlib.md5(answer.encode('utf-8')).hexdigest()  # stable file name from the text
    scipy.io.wavfile.write("{}.wav".format(md5), rate=speech["sampling_rate"], data=speech["audio"])  # save audio
    return "{}.wav".format(md5)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings, query_engine, index_type):
    """Answer `msg`, extend the chat history, and voice the newest reply.

    Dispatches to the self-built embedding index when index_type == 1,
    otherwise to the llama_index query engine, then renders the latest
    answer as audio via get_audio_answer().

    :return: (updated history, gr.Audio component for the spoken reply)
    """
    use_self_index = index_type == 1
    if use_self_index:
        history = get_response_by_self(open_ai_key, msg, bot, doc_text_list, doc_embeddings)
    else:
        history = get_response_by_llama_index(open_ai_key, msg, bot, query_engine)
    # history[-1] is the (question, answer) pair that was just produced.
    wav_path = get_audio_answer(history[-1][1])
    return history, gr.Audio(wav_path)
|
| 164 |
|
| 165 |
|
| 166 |
def up_file(files): # 上传文件
|
|
|
|
| 251 |
with gr.Column():
|
| 252 |
md = gr.Markdown("""操作说明 step 1:点击左侧区域,上传PDF,进行解析""") # 操作说明
|
| 253 |
chat_bot = gr.Chatbot(visible=False) # 聊天机器人
|
| 254 |
+
audio_answer = gr.Audio() # 语音回答
|
| 255 |
with gr.Row():
|
| 256 |
asr_type = gr.Radio(value='self', choices=['self', 'openai'], label='语音识别方式', visible=False) # 语音识别方式
|
| 257 |
audio_inputs = gr.Audio(source="microphone", type="filepath", label="点击录音输入", visible=False) # 录音输入
|
|
|
|
| 267 |
audio_inputs.change(transcribe_speech, [open_ai_key, audio_inputs, asr_type], [msg_txt]) # 录音输入
|
| 268 |
chat_bu.click(get_response,
|
| 269 |
[open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state, query_engine, index_type],
|
| 270 |
+
[chat_bot, audio_answer]) # 发送消息
|
| 271 |
|
| 272 |
# Script entry point: enable the request queue and launch the Gradio app.
# NOTE(review): queue(concurrency_count=4) is gradio 3.x API; the argument was
# removed in gradio 4.x — confirm the pinned gradio version for this Space.
if __name__ == "__main__":
    demo.queue(concurrency_count=4).launch()
|