Zai / app.py
huynhkimthien's picture
Update app.py
bda66ed verified
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import whisper
import torch
from gtts import gTTS
import os
import yt_dlp
import re
hf_token = os.getenv("HF_TOKEN")
app = FastAPI()
# Load Qwen model
model_name = "Qwen/Qwen3-4B-Instruct-2507"
tokenizer = AutoTokenizer.from_pretrained(model_name,token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
model_name,
token=hf_token,
device_map={"": "cpu"},
dtype=torch.float32
)
# Load Whisper model
whisper_model = whisper.load_model("base")
# Lưu hội thoại
conversation = [{"role": "system", "content": "Bạn là một trợ lý AI. Hãy trả lời ngắn gọn, súc tích, tối đa 2 câu."}]
# Hàm trích xuất tên bài hát từ văn bản
def extract_song_name(text):
import re
match = re.search(r"(bài|bài hát|nghe nhạc|mở nhạc)\s+(.*)", text.lower())
if match:
return match.group(2).strip()
return None
def download_youtube_as_wav(song_name, output_path="song.wav"):
search_query = f"ytsearch1:{song_name}"
ydl_opts = {
'format': 'bestaudio/best',
'outtmpl': 'temp_audio.%(ext)s',
'postprocessors': [{
'key': 'FFmpegExtractAudio',
'preferredcodec': 'wav',
'preferredquality': '192',
}],
'quiet': True,
}
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
ydl.download([search_query])
if os.path.exists("temp_audio.wav"):
os.rename("temp_audio.wav", output_path)
return output_path
return None
class ChatRequest(BaseModel):
message: str
@app.get("/")
def read_root():
return {"message": "Ứng dụng đang chạy!"}
# Endpoint chat text
@app.post("/chat")
async def chat(request: ChatRequest):
conversation.append({"role": "user", "content": request.message})
text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
response_text = generate_full_response(model_inputs)
conversation.append({"role": "assistant", "content": response_text})
return {"response": response_text}
# Endpoint voice chat + TTS
@app.post("/voice_chat")
async def voice_chat(file: UploadFile = File(...)):
file_location = f"temp_{file.filename}"
with open(file_location, "wb") as f:
f.write(await file.read())
result = whisper_model.transcribe(file_location, language="vi")
user_text = result["text"]
os.remove(file_location)
# Kiểm tra yêu cầu mở nhạc
if any(kw in user_text.lower() for kw in ["nghe nhạc", "mở bài hát", "bài hát", "bài"]):
song_name = extract_song_name(user_text)
if song_name:
wav_path = download_youtube_as_wav(song_name)
if wav_path:
return FileResponse(wav_path, media_type="audio/wav", filename="song.wav")
else:
return {"error": "Không tìm thấy hoặc tải được bài hát."}
# Nếu không phải yêu cầu mở nhạc → xử lý như cũ
conversation.append({"role": "user", "content": user_text})
text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
response_text = generate_full_response(model_inputs)
conversation.append({"role": "assistant", "content": response_text})
tts = gTTS(response_text, lang="vi")
audio_file = "response.mp3"
tts.save(audio_file)
return {
"user_text": user_text,
"response": response_text,
"audio_url": f"/get_audio"
}
# Endpoint trả về file âm thanh
@app.get("/get_audio")
async def get_audio():
return FileResponse("response.mp3", media_type="audio/mpeg")
# Hàm sinh phản hồi
def generate_full_response(model_inputs, max_new_tokens=64):
with torch.inference_mode():
generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
response_text = tokenizer.decode(output_ids, skip_special_tokens=True)
return response_text.strip()