Spaces:
Running
Running
File size: 5,200 Bytes
f43e847 f9813cf 40d82fc 156b965 40d82fc 156b965 40d82fc bda66ed f9813cf 3a144a0 f9813cf bda66ed 156b965 da03d90 156b965 f9813cf 5c0fa7a d75e6c8 4e9a8cd b7dc424 4e9a8cd 5c0fa7a 78895c8 156b965 04a1d7e f9813cf bda66ed f9813cf bda66ed 40d82fc d86709d 40d82fc 202b2da 40d82fc 156b965 40d82fc b8e20b6 04a1d7e e95508c 156b965 da03d90 f9813cf df28e1d 1a94ea5 287106c 1a94ea5 287106c 1a94ea5 b9229bb 1a94ea5 f9813cf 156b965 b8e20b6 f9813cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import io
import os
import re
import shutil
import tempfile

import numpy as np
import scipy.io.wavfile as wav
import torch
import whisper
import yt_dlp
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.responses import FileResponse, JSONResponse
from gtts import gTTS
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face access token from the environment (None when HF_TOKEN is unset);
# passed to from_pretrained for both the tokenizer and the model.
hf_token = os.getenv("HF_TOKEN")
app = FastAPI()
# Load the Qwen chat model and its tokenizer at import time, pinned to CPU in float32.
model_name = "Qwen/Qwen3-4B-Instruct-2507"
tokenizer = AutoTokenizer.from_pretrained(model_name,token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=hf_token,
    device_map={"": "cpu"},
    dtype=torch.float32
)
# Load the Whisper speech-recognition model ("base" checkpoint).
whisper_model = whisper.load_model("base")
# Conversation history: a module-level list, so it is shared by every client of
# /chat and /voice_chat. The (Vietnamese) system prompt tells the assistant to
# answer briefly and concisely, in at most 2 sentences.
conversation = [{"role": "system", "content": "Bạn là một trợ lý AI. Hãy trả lời ngắn gọn, súc tích, tối đa 2 câu."}]
# Extract the requested song name from a (speech-transcribed) user utterance.
def extract_song_name(text):
    """Return the song name that follows a music-request keyword in *text*.

    Recognized keywords: "bài hát", "nghe nhạc", "mở nhạc", "bài". The text is
    NFC-normalized and lower-cased before matching. Whitespace and punctuation
    around the extracted name are stripped.

    Args:
        text: User utterance, e.g. a Whisper transcript.

    Returns:
        The lower-cased song name, or ``None`` when no keyword followed by a
        non-empty name is present.
    """
    import re
    import unicodedata

    # Precomposed vs. decomposed Vietnamese diacritics would otherwise make the
    # keyword literals below fail to match.
    normalized = unicodedata.normalize("NFC", text).lower()
    # Longest alternatives first: regex alternation takes the first branch that
    # matches, so listing "bài" before "bài hát" turned "bài hát X" into "hát X".
    match = re.search(r"(bài hát|nghe nhạc|mở nhạc|bài)\s+(.*)", normalized)
    if match is None:
        return None
    # Transcripts may carry trailing punctuation (e.g. "lạc trôi."), which would
    # otherwise end up in the YouTube search query.
    song_name = match.group(2).strip(" \t\r\n.,!?;:…\"'“”")
    return song_name or None
def download_youtube_as_wav(song_name, output_path="song.wav"):
    """Download the top YouTube search hit for *song_name* and save it as WAV.

    The audio is fetched with yt-dlp (``ytsearch1:<song_name>``) and converted
    to WAV by yt-dlp's FFmpegExtractAudio post-processor, so ffmpeg must be
    available.

    Args:
        song_name: Free-text search query.
        output_path: Destination WAV path; an existing file is replaced.

    Returns:
        ``output_path`` on success, or ``None`` when yt-dlp reports a download
        error or produces no WAV file.
    """
    search_query = f"ytsearch1:{song_name}"
    # Work in a private temp directory. The previous fixed "temp_audio.*" name
    # in the CWD is also the file /voice_chat writes its recording to, so a
    # search that downloaded nothing would "find" the user's own recording.
    with tempfile.TemporaryDirectory() as workdir:
        ydl_opts = {
            "format": "bestaudio/best",
            "outtmpl": os.path.join(workdir, "temp_audio.%(ext)s"),
            "postprocessors": [{
                "key": "FFmpegExtractAudio",
                "preferredcodec": "wav",
                "preferredquality": "192",
            }],
            "quiet": True,
        }
        try:
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                ydl.download([search_query])
        except yt_dlp.utils.DownloadError:
            # Report "not found / not downloadable" to the caller instead of
            # surfacing it as a generic server error.
            return None
        downloaded = os.path.join(workdir, "temp_audio.wav")
        if not os.path.exists(downloaded):
            return None
        # shutil.move overwrites an existing destination on every platform and
        # works across filesystems (the temp dir may be on another mount).
        shutil.move(downloaded, output_path)
        return output_path
class ChatRequest(BaseModel):
    """JSON request body for the ``/chat`` endpoint."""

    # The user's message text; /chat appends it to the shared conversation history.
    message: str
@app.get("/")
def read_root():
    """Liveness endpoint: report that the application is running."""
    status_text = "Ứng dụng đang chạy!"
    return dict(message=status_text)
# Text chat endpoint.
@app.post("/chat")
async def chat(request: ChatRequest):
    """Answer ``request.message`` with the Qwen model, using the shared history.

    Returns:
        ``{"response": <assistant reply>}``.

    The user turn and the assistant reply are appended to the module-level
    ``conversation`` only after generation succeeds. A failed generation
    therefore no longer leaves an unanswered user turn in the shared history.
    """
    user_turn = {"role": "user", "content": request.message}
    # Build the prompt from a copy so the shared history is untouched on failure.
    prompt = tokenizer.apply_chat_template(
        conversation + [user_turn], tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    # NOTE(review): generation runs synchronously inside this async handler, so
    # it blocks the event loop (and every other request) until it finishes.
    # Offloading it to a thread would also need a lock around `conversation`.
    # The history is shared by all clients and is never trimmed.
    response_text = generate_full_response(model_inputs)
    conversation.extend(
        (user_turn, {"role": "assistant", "content": response_text})
    )
    return {"response": response_text}
# Voice chat endpoint: speech -> text -> (song download | LLM reply + TTS).
def _pcm24_to_int16(raw_audio):
    """Convert packed 3-byte PCM frames to int16 by keeping each frame's top two bytes.

    Byte 0 of each frame is treated as the most significant byte, which
    reproduces ``((b0 << 16 | b1 << 8 | b2) >> 8).astype(int16)``. A trailing
    partial frame is dropped instead of crashing ``reshape``.
    NOTE(review): assumes the client sends signed 24-bit big-endian mono PCM;
    confirm against the device firmware.
    """
    usable = len(raw_audio) - len(raw_audio) % 3
    frames = np.frombuffer(memoryview(raw_audio)[:usable], dtype=np.uint8).reshape(-1, 3)
    # Values >= 0x8000 wrap to negative on the int16 cast (two's complement).
    return ((frames[:, 0].astype(np.int32) << 8) | frames[:, 1]).astype(np.int16)


@app.post("/voice_chat")
async def voice_chat(request: Request):
    """Transcribe raw audio, then either play a requested song or reply with speech.

    The request body is raw packed 3-byte PCM (see ``_pcm24_to_int16``).

    Returns:
        * A WAV ``FileResponse`` when the transcript asks for a song that
          could be downloaded.
        * ``{"user_text", "response", "audio_url"}``: the LLM reply, also
          synthesized to ``response.mp3`` and served by ``/get_audio``.
        * JSON ``{"error": ...}`` with status 400 (no audio), 404 (song not
          found), 422 (nothing recognized), or 500 (any other failure).
    """
    try:
        raw_audio = await request.body()
        if len(raw_audio) < 3:
            return JSONResponse({"error": "Không nhận được dữ liệu âm thanh."}, status_code=400)
        audio_int16 = _pcm24_to_int16(raw_audio)
        # Whisper accepts a 16 kHz mono float32 waveform directly. Scaling by
        # 1/32768 matches what whisper.load_audio produces from a 16-bit file,
        # so no temp WAV file (previously shared with the song downloader) is
        # needed. NOTE(review): assumes the client records at 16 kHz, as the
        # original hard-coded sample rate implied.
        audio_float = audio_int16.astype(np.float32) / 32768.0
        result = whisper_model.transcribe(audio_float, language="vi")
        user_text = result["text"].strip()
        if not user_text:
            return JSONResponse({"error": "Không nhận dạng được giọng nói."}, status_code=422)

        # Music request? extract_song_name returns None when no request keyword
        # is present, so a separate (and inconsistent) keyword list is not needed.
        song_name = extract_song_name(user_text)
        if song_name:
            wav_path = download_youtube_as_wav(song_name)
            if wav_path:
                return FileResponse(wav_path, media_type="audio/wav", filename="song.wav")
            return JSONResponse({"error": "Không tìm thấy hoặc tải được bài hát."}, status_code=404)

        # Conversation: commit both turns only after generation succeeds.
        user_turn = {"role": "user", "content": user_text}
        prompt = tokenizer.apply_chat_template(
            conversation + [user_turn], tokenize=False, add_generation_prompt=True
        )
        model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
        response_text = generate_full_response(model_inputs)
        conversation.extend(
            (user_turn, {"role": "assistant", "content": response_text})
        )

        # Text-to-speech. NOTE(review): response.mp3 is shared by all clients,
        # so /get_audio returns whichever reply was synthesized last.
        tts = gTTS(response_text, lang="vi")
        tts.save("response.mp3")
        return {
            "user_text": user_text,
            "response": response_text,
            "audio_url": "/get_audio",
        }
    except Exception as e:
        # Endpoint boundary: report any failure as a JSON 500.
        return JSONResponse({"error": str(e)}, status_code=500)
# Endpoint that serves the latest synthesized reply audio.
@app.get("/get_audio")
async def get_audio():
    """Serve ``response.mp3``, the most recent TTS reply written by ``/voice_chat``.

    Returns:
        The MP3 as a ``FileResponse``. If no reply has been synthesized yet,
        returns a 404 JSON error; otherwise FileResponse would fail on the
        missing file with a server error.
    """
    audio_file = "response.mp3"
    if not os.path.isfile(audio_file):
        return JSONResponse({"error": "Chưa có âm thanh phản hồi."}, status_code=404)
    return FileResponse(audio_file, media_type="audio/mpeg")
# Generate a reply from tokenized chat inputs.
def generate_full_response(model_inputs, max_new_tokens=64):
    """Run the global model on *model_inputs* and decode only the new tokens.

    Args:
        model_inputs: Tokenizer output (with ``input_ids``) for a single prompt.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The decoded continuation, with special tokens skipped and surrounding
        whitespace stripped.
    """
    prompt_length = len(model_inputs.input_ids[0])
    with torch.inference_mode():
        sequences = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation; keep only the continuation.
    continuation_ids = sequences[0][prompt_length:].tolist()
    decoded = tokenizer.decode(continuation_ids, skip_special_tokens=True)
    return decoded.strip()