from fastapi import FastAPI, File, UploadFile, Request
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import whisper
import torch
from gtts import gTTS
import os
import yt_dlp
import re
import numpy as np
import scipy.io.wavfile as wav

hf_token = os.getenv("HF_TOKEN")
app = FastAPI()

# Load the Qwen chat model (CPU-only, full precision)
model_name = "Qwen/Qwen3-4B-Instruct-2507"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=hf_token,
    device_map={"": "cpu"},
    dtype=torch.float32
)

# Load Whisper model
whisper_model = whisper.load_model("base")

# Conversation history, shared across requests; seeded with a Vietnamese system
# prompt ("You are an AI assistant. Answer briefly and concisely, at most 2 sentences.")
conversation = [{"role": "system", "content": "Bạn là một trợ lý AI. Hãy trả lời ngắn gọn, súc tích, tối đa 2 câu."}]

# Extract the requested song name from the transcribed text.
# The keywords are Vietnamese: "bài"/"bài hát" = (the) song, "nghe nhạc"/"mở nhạc" = play music.
def extract_song_name(text):
    # Longest alternatives first, so "bài hát" matches as a unit before the bare "bài"
    match = re.search(r"(bài hát|nghe nhạc|mở nhạc|bài)\s+(.*)", text.lower())
    if match:
        return match.group(2).strip()
    return None
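
# For illustration (hypothetical input): extract_song_name("Mở nhạc Nối vòng tay lớn")
# returns "nối vòng tay lớn", since the text is lowercased before matching.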

def download_youtube_as_wav(song_name, output_path="song.wav"):
    # Search YouTube for the top match and download its audio as WAV
    # (the FFmpegExtractAudio postprocessor requires ffmpeg on PATH)
    search_query = f"ytsearch1:{song_name}"
    ydl_opts = {
        'format': 'bestaudio/best',
        'outtmpl': 'temp_audio.%(ext)s',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
            'preferredquality': '192',
        }],
        'quiet': True,
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([search_query])
    if os.path.exists("temp_audio.wav"):
        os.rename("temp_audio.wav", output_path)
        return output_path
    return None

class ChatRequest(BaseModel):
    message: str

@app.get("/")
def read_root():
    # Health check ("Ứng dụng đang chạy!" = "The app is running!")
    return {"message": "Ứng dụng đang chạy!"}

# Text chat endpoint
@app.post("/chat")
async def chat(request: ChatRequest):
    conversation.append({"role": "user", "content": request.message})
    text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    response_text = generate_full_response(model_inputs)
    conversation.append({"role": "assistant", "content": response_text})
    return {"response": response_text}
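
# A minimal sketch of calling /chat (host and port are assumptions, not part of the app):
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Xin chào"}'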

# Voice chat endpoint: speech-to-text, LLM reply, then text-to-speech
@app.post("/voice_chat")
async def voice_chat(request: Request):
    try:
        raw_audio = await request.body()
        sample_rate = 16000

        # Decode 24-bit PCM: each sample is 3 bytes, most significant byte first
        audio_np = np.frombuffer(raw_audio, dtype=np.uint8)
        audio_np = audio_np[: len(audio_np) - len(audio_np) % 3].reshape(-1, 3)
        audio_int = (audio_np[:, 0].astype(np.int32) << 16) | \
            (audio_np[:, 1].astype(np.int32) << 8) | \
            audio_np[:, 2].astype(np.int32)
        # Sign-extend from 24 to 32 bits so negative samples decode correctly
        audio_int = (audio_int << 8) >> 8

        # Keep the top 16 bits, yielding int16 for the WAV writer
        audio_int16 = (audio_int >> 8).astype(np.int16)

        # Write the samples to a mono WAV file for Whisper
        wav.write("temp_audio.wav", sample_rate, audio_int16)

        # Transcribe the audio with Whisper (Vietnamese)
        result = whisper_model.transcribe("temp_audio.wav", language="vi")
        user_text = result["text"]

        # Check whether the user asked to play music
        if any(kw in user_text.lower() for kw in ["nghe nhạc", "mở bài hát", "bài hát", "bài"]):
            song_name = extract_song_name(user_text)
            if song_name:
                wav_path = download_youtube_as_wav(song_name)
                if wav_path:
                    return FileResponse(wav_path, media_type="audio/wav", filename="song.wav")
                else:
                    return JSONResponse({"error": "Không tìm thấy hoặc tải được bài hát."}, status_code=404)

        # Otherwise, handle it as normal conversation
        conversation.append({"role": "user", "content": user_text})
        text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        response_text = generate_full_response(model_inputs)
        conversation.append({"role": "assistant", "content": response_text})

        # Synthesize the reply as Vietnamese speech with gTTS
        tts = gTTS(response_text, lang="vi")
        audio_file = "response.mp3"
        tts.save(audio_file)

        return {
            "user_text": user_text,
            "response": response_text,
            "audio_url": "/get_audio"
        }

    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

# Endpoint that serves the synthesized reply audio
@app.get("/get_audio")
async def get_audio():
    return FileResponse("response.mp3", media_type="audio/mpeg")

# Generate the model's full reply text from the prepared inputs
def generate_full_response(model_inputs, max_new_tokens=64):
    with torch.inference_mode():
        generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
    response_text = tokenizer.decode(output_ids, skip_special_tokens=True)
    return response_text.strip()
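
# --- Usage sketch (assumptions, not part of the original app) ---
# Run locally with uvicorn (module name and port are assumed):
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
# /voice_chat expects the raw request body to be 24-bit PCM audio:
# 3 bytes per sample, most significant byte first, mono, 16 kHz.
# A hypothetical client:
#
#   import requests
#   with open("speech_24bit_16khz.raw", "rb") as f:
#       r = requests.post("http://localhost:7860/voice_chat", data=f.read())
#
# Music requests stream back a WAV file; chat requests return JSON with the
# "response" text and an "audio_url" pointing at /get_audio.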