# qwen_api / app.py
# Source: Hugging Face repository file view (author: aryo100,
# commit 53ee96a "update app", 1.31 kB).
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch, os, uvicorn
# ASGI application object; the route decorators in this module attach to it.
app = FastAPI()

# Checkpoint identifier to load; change it to pick a different model size.
model_name = "Qwen/Qwen-1_8B-Chat"

# trust_remote_code=True allows the custom tokenizer/model code shipped in the
# checkpoint repository to be loaded.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Pick precision and placement together: float16 with automatic device
# mapping when CUDA is available, otherwise float32 on the CPU.
if torch.cuda.is_available():
    _torch_dtype, _device_map = torch.float16, "auto"
else:
    _torch_dtype, _device_map = torch.float32, "cpu"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=_torch_dtype,
    device_map=_device_map,
)
class ChatRequest(BaseModel):
    """Request body accepted by the ``/chat`` endpoint."""

    # Text inserted as the user turn of the conversation.
    prompt: str
    # Forwarded to ``model.generate`` as ``max_new_tokens``; defaults to 128.
    max_new_tokens: int = 128
@app.post("/chat")
def chat(req: ChatRequest):
    """Generate one assistant reply for the prompt in ``req``.

    The prompt is wrapped in a fixed system + user conversation, rendered to
    text with the tokenizer's chat template, tokenized, and passed to
    ``model.generate``.

    Args:
        req: Request body carrying the user prompt and ``max_new_tokens``,
            the generation budget forwarded to ``model.generate``.

    Returns:
        A dict ``{"reply": text}`` where ``text`` is the decoded continuation
        only (the prompt tokens are stripped), with special tokens skipped.
    """
    # Conversation in the role/content message format used by chat templates.
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": req.prompt},
    ]

    # NOTE(review): relies on the loaded tokenizer providing a chat template;
    # confirm this holds for the configured checkpoint.
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=req.max_new_tokens)

    # For a causal LM, generate() returns the prompt tokens followed by the
    # newly generated ones. Decoding the full sequence would echo the system
    # and user text back in the reply, so keep only the continuation.
    prompt_length = inputs["input_ids"].shape[-1]
    generated_tokens = outputs[0][prompt_length:]
    reply = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return {"reply": reply}
if __name__ == "__main__":
    # Listen on all interfaces; the port is read from the PORT environment
    # variable and falls back to 7860 when it is unset.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", "7860")),
    )