# qwen_api/app.py
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
import os
import uvicorn
import threading

app = FastAPI()
# Load model & tokenizer once at startup
# MODEL_NAME = "Qwen/Qwen1.5-1.8B-Chat"
MODEL_NAME = "Qwen/Qwen1.5-4B-Chat"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    trust_remote_code=True,
)
model.config.use_cache = True
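# Note (assumption, not from the original): bfloat16 on CPU works on recent PyTorch
# builds but can be slow; torch.float32 is a safe fallback if bf16 kernels are missing.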

# Fallback in case chat_template is empty
if not tokenizer.chat_template:
    tokenizer.chat_template = """{% for message in messages %}
{{ message['role'] }}: {{ message['content'] }}
{% endfor %}
assistant:"""

# Request schema
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    messages: list[Message]
    max_new_tokens: int = 128

# Token-streaming generator
def generate_stream(prompt, max_new_tokens=128):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # skip_prompt=True keeps the echoed input prompt out of the stream
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
        streamer=streamer,
    )
    # Run generate() in a background thread; the streamer yields text as it is produced
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # Yield tokens directly; the endpoint wraps this generator in a StreamingResponse
    for token in streamer:
        yield token
@app.post("/stream")
async def chat(req: ChatRequest):
# Format prompt sesuai chat template
text = tokenizer.apply_chat_template(
req.messages,
tokenize=False,
add_generation_prompt=True
)
generator = generate_stream(text, req.max_new_tokens)
return StreamingResponse(generator, media_type="text/plain")
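
# Illustrative /stream client, not part of this app: a minimal sketch assuming the
# server runs locally on port 7860, using the `requests` library's streaming mode.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/stream",
#       json={"messages": [{"role": "user", "content": "Hello"}], "max_new_tokens": 64},
#       stream=True,
#   )
#   for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#       print(chunk, end="", flush=True)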
@app.post("/chat")
def chat(req: ChatRequest):
text = tokenizer.apply_chat_template(
[m.model_dump() for m in req.messages],
tokenize=False,
add_generation_prompt=True
)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=req.max_new_tokens,
do_sample=True,
top_p=0.9,
temperature=0.7
)
    # Decode only the newly generated tokens, dropping the echoed prompt
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )
    return {"response": response}
@app.get("/")
def root():
return {"message": "Qwen FastAPI running πŸš€"}
if __name__ == "__main__":
port = int(os.environ.get("PORT", 7860))
uvicorn.run("app:app", host="0.0.0.0", port=port)