from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
import os
import uvicorn
import threading

app = FastAPI()

# Load the model & tokenizer once at startup
# MODEL_NAME = "Qwen/Qwen1.5-1.8B-Chat"
MODEL_NAME = "Qwen/Qwen1.5-4B-Chat"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    trust_remote_code=True,
)
model.config.use_cache = True

# Fallback in case the tokenizer ships without a chat_template
if not tokenizer.chat_template:
    tokenizer.chat_template = """{% for message in messages %}
{{ message['role'] }}: {{ message['content'] }}
{% endfor %}
assistant:"""


# Request schema
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: list[Message]
    max_new_tokens: int = 128


# Generator that yields tokens as they are produced
def generate_stream(prompt, max_new_tokens=128):
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # TextIteratorStreamer lets model.generate() run in a background thread
    # while this generator yields decoded tokens as they arrive.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def run_generation():
        try:
            model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                streamer=streamer,
                use_cache=True,
            )
        except Exception:
            # Swallow generation errors so the streaming response can end cleanly;
            # the client simply receives whatever was produced before the failure.
            pass

    thread = threading.Thread(target=run_generation, daemon=True)
    thread.start()

    for token in streamer:
        yield token


@app.post("/stream")
async def stream_chat(req: ChatRequest):
    # Format the prompt with the model's chat template
    text = tokenizer.apply_chat_template(
        [m.model_dump() for m in req.messages],
        tokenize=False,
        add_generation_prompt=True,
    )
    generator = generate_stream(text, req.max_new_tokens)
    return StreamingResponse(generator, media_type="text/plain")
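
# A minimal client sketch for the /stream endpoint, assuming the server is
# running locally on the default port 7860. The URL, payload, and the
# `requests` dependency are illustrative only; nothing in the app calls this.
def example_stream_client():
    import requests  # assumed to be installed on the client side

    payload = {
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_new_tokens": 64,
    }
    # stream=True keeps the connection open so tokens can be printed as they arrive
    with requests.post("http://localhost:7860/stream", json=payload, stream=True) as resp:
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="", flush=True)
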
@app.post("/chat")
def chat(req: ChatRequest):
    text = tokenizer.apply_chat_template(
        [m.model_dump() for m in req.messages],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=req.max_new_tokens,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
    )
    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )
    return {"response": response}


@app.get("/")
def root():
    return {"message": "Qwen FastAPI running 🚀"}


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port)
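
# Example call to the non-streaming /chat endpoint (assuming the server
# listens on port 7860; adjust host and port as needed):
#
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}], "max_new_tokens": 64}'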