from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
import os
import uvicorn
import threading

app = FastAPI()

# Load the model & tokenizer once at startup
# MODEL_NAME = "Qwen/Qwen1.5-1.8B-Chat"
MODEL_NAME = "Qwen/Qwen1.5-4B-Chat"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
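# Note: bfloat16 on CPU roughly halves memory versus float32; if your PyTorch
# CPU build handles bfloat16 poorly, torch.float32 is a safe fallback here.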
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    trust_remote_code=True,
)
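# Keep the KV cache enabled so generation reuses past attention states for each new token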
model.config.use_cache = True

# Fallback in case chat_template is empty
if not tokenizer.chat_template:
    tokenizer.chat_template = """{% for message in messages %}
{{ message['role'] }}: {{ message['content'] }}
{% endfor %}
assistant:"""

# Request schema
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    messages: list[Message]
    max_new_tokens: int = 128
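
# Example request body for /chat and /stream (illustrative values):
#   {"messages": [{"role": "user", "content": "Hello"}], "max_new_tokens": 64}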

# Generator for streaming tokens
def generate_stream(prompt, max_new_tokens=128):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Stream decoded text as it is produced; skip the prompt and special tokens
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
        streamer=streamer,
    )

    # Run generate() in a background thread so tokens can be consumed as they arrive
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield plain-text chunks; the endpoint wraps this generator in a StreamingResponse
    for token in streamer:
        yield token


@app.post("/stream")
async def chat(req: ChatRequest):
    # Format prompt sesuai chat template
    text = tokenizer.apply_chat_template(
        req.messages,
        tokenize=False,
        add_generation_prompt=True
    )

    generator = generate_stream(text, req.max_new_tokens)
    return StreamingResponse(generator, media_type="text/plain")

@app.post("/chat")
def chat(req: ChatRequest):
    text = tokenizer.apply_chat_template(
        [m.model_dump() for m in req.messages],
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=req.max_new_tokens,
        do_sample=True,
        top_p=0.9,
        temperature=0.7
    )

    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True
    )

    return {"response": response}

@app.get("/")
def root():
    return {"message": "Qwen FastAPI running 🚀"}

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port)
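
# Example calls against a local instance (adjust host/port for your deployment;
# 7860 is the default PORT above):
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hi"}], "max_new_tokens": 64}'
#   curl -N -X POST http://localhost:7860/stream \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hi"}], "max_new_tokens": 64}'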