Spaces:
Running
Running
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import whisper | |
| import torch | |
| from gtts import gTTS | |
| import os | |
| import yt_dlp | |
| import re | |
| hf_token = os.getenv("HF_TOKEN") | |
| app = FastAPI() | |
| # Load Qwen model | |
| model_name = "Qwen/Qwen3-4B-Instruct-2507" | |
| tokenizer = AutoTokenizer.from_pretrained(model_name,token=hf_token) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| token=hf_token, | |
| device_map={"": "cpu"}, | |
| dtype=torch.float32 | |
| ) | |
| # Load Whisper model | |
| whisper_model = whisper.load_model("base") | |
| # Lưu hội thoại | |
| conversation = [{"role": "system", "content": "Bạn là một trợ lý AI. Hãy trả lời ngắn gọn, súc tích, tối đa 2 câu."}] | |
| # Hàm trích xuất tên bài hát từ văn bản | |
| def extract_song_name(text): | |
| import re | |
| match = re.search(r"(bài|bài hát|nghe nhạc|mở nhạc)\s+(.*)", text.lower()) | |
| if match: | |
| return match.group(2).strip() | |
| return None | |
| def download_youtube_as_wav(song_name, output_path="song.wav"): | |
| search_query = f"ytsearch1:{song_name}" | |
| ydl_opts = { | |
| 'format': 'bestaudio/best', | |
| 'outtmpl': 'temp_audio.%(ext)s', | |
| 'postprocessors': [{ | |
| 'key': 'FFmpegExtractAudio', | |
| 'preferredcodec': 'wav', | |
| 'preferredquality': '192', | |
| }], | |
| 'quiet': True, | |
| } | |
| with yt_dlp.YoutubeDL(ydl_opts) as ydl: | |
| ydl.download([search_query]) | |
| if os.path.exists("temp_audio.wav"): | |
| os.rename("temp_audio.wav", output_path) | |
| return output_path | |
| return None | |
| class ChatRequest(BaseModel): | |
| message: str | |
| def read_root(): | |
| return {"message": "Ứng dụng đang chạy!"} | |
| # Endpoint chat text | |
| async def chat(request: ChatRequest): | |
| conversation.append({"role": "user", "content": request.message}) | |
| text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) | |
| model_inputs = tokenizer([text], return_tensors="pt").to(model.device) | |
| response_text = generate_full_response(model_inputs) | |
| conversation.append({"role": "assistant", "content": response_text}) | |
| return {"response": response_text} | |
| # Endpoint voice chat + TTS | |
| async def voice_chat(file: UploadFile = File(...)): | |
| file_location = f"temp_{file.filename}" | |
| with open(file_location, "wb") as f: | |
| f.write(await file.read()) | |
| result = whisper_model.transcribe(file_location, language="vi") | |
| user_text = result["text"] | |
| os.remove(file_location) | |
| # Kiểm tra yêu cầu mở nhạc | |
| if any(kw in user_text.lower() for kw in ["nghe nhạc", "mở bài hát", "bài hát", "bài"]): | |
| song_name = extract_song_name(user_text) | |
| if song_name: | |
| wav_path = download_youtube_as_wav(song_name) | |
| if wav_path: | |
| return FileResponse(wav_path, media_type="audio/wav", filename="song.wav") | |
| else: | |
| return {"error": "Không tìm thấy hoặc tải được bài hát."} | |
| # Nếu không phải yêu cầu mở nhạc → xử lý như cũ | |
| conversation.append({"role": "user", "content": user_text}) | |
| text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) | |
| model_inputs = tokenizer([text], return_tensors="pt").to(model.device) | |
| response_text = generate_full_response(model_inputs) | |
| conversation.append({"role": "assistant", "content": response_text}) | |
| tts = gTTS(response_text, lang="vi") | |
| audio_file = "response.mp3" | |
| tts.save(audio_file) | |
| return { | |
| "user_text": user_text, | |
| "response": response_text, | |
| "audio_url": f"/get_audio" | |
| } | |
| # Endpoint trả về file âm thanh | |
| async def get_audio(): | |
| return FileResponse("response.mp3", media_type="audio/mpeg") | |
| # Hàm sinh phản hồi | |
| def generate_full_response(model_inputs, max_new_tokens=64): | |
| with torch.inference_mode(): | |
| generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens) | |
| output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() | |
| response_text = tokenizer.decode(output_ids, skip_special_tokens=True) | |
| return response_text.strip() |