# save as app.py
import threading
import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
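# Assumed environment (versions are not pinned in the original): gradio,
# transformers, torch, and accelerate installed, e.g.
#   pip install gradio transformers torch accelerate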
MODEL_ID = "EpistemeAI/gpt-oss-20b-RL"
print("Loading tokenizer and model (this may take a while)...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Always use auto dtype / device mapping
# (device_map="auto" requires the accelerate package).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)
model.eval()
print("Model loaded. Device:", next(model.parameters()).device)
# --------- Helper: build prompt ----------
def build_prompt(system_message: str, history: list[dict], user_message: str) -> str:
    pieces = []
    if system_message:
        pieces.append(f"<|system|>\n{system_message}\n")
    for turn in history:
        role = turn.get("role", "user")
        content = turn.get("content", "")
        pieces.append(f"<|{role}|>\n{content}\n")
    pieces.append(f"<|user|>\n{user_message}\n<|assistant|>\n")
    return "\n".join(pieces)
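# Alternative sketch (not part of the original app): if the model's tokenizer
# ships a chat template, the hand-rolled markers above could be replaced with
# the tokenizer's own formatting. This assumes MODEL_ID actually provides a
# chat template.
def build_prompt_with_template(system_message: str, history: list[dict], user_message: str) -> str:
    messages = [{"role": "system", "content": system_message}] if system_message else []
    messages += history + [{"role": "user", "content": user_message}]
    # tokenize=False returns the formatted string; add_generation_prompt=True
    # appends the assistant header so generation continues from there.
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )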
# --------- Streaming generator ----------
def generate_stream(prompt, max_tokens, temperature, top_p):
    # Move both input_ids and attention_mask to the model's device;
    # passing the mask avoids a warning from generate().
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        streamer=streamer,
    )
    # generate() blocks, so run it on a worker thread and consume the
    # streamer here, yielding the accumulated text as it grows.
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    partial = ""
    for token_str in streamer:
        partial += token_str
        yield partial
    thread.join()
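# Standalone usage sketch (no Gradio needed):
#     for text_so_far in generate_stream("Hello!", max_tokens=64, temperature=0.7, top_p=0.9):
#         print(text_so_far)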
# --------- Gradio app logic ----------
def respond_stream(user_message, chat_history, system_message, max_tokens, temperature, top_p):
    # Reuse the session's history list so it persists across turns;
    # it is yielded back to the State component alongside the chatbot rows.
    history = chat_history if chat_history is not None else []
    prompt = build_prompt(system_message or "", history, user_message or "")
    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": ""})

    def history_to_chatbot_rows(hist):
        # Render only assistant turns; the user/system column is masked with a
        # placeholder, matching the "user/system hidden" label on the Chatbot.
        rows = []
        for item in hist:
            if item["role"] == "assistant":
                rows.append(("thinking...", item["content"] or "thinking..."))
        return rows

    chatbot_rows = history_to_chatbot_rows(history[:-1])
    chatbot_rows.append(("thinking...", "thinking..."))
    yield chatbot_rows, history  # placeholder row while generation spins up

    for partial in generate_stream(prompt, max_tokens, temperature, top_p):
        chatbot_rows[-1] = ("thinking...", partial)
        history[-1]["content"] = partial
        yield chatbot_rows, history

    # Final yield commits the completed reply and the updated history.
    chatbot_rows[-1] = ("thinking...", history[-1]["content"])
    yield chatbot_rows, history
# --------- Build Gradio UI ----------
with gr.Blocks() as demo:
    gr.Markdown(f"**Model:** {MODEL_ID}")
    with gr.Row():
        # .style() was removed in recent Gradio; height is a constructor argument.
        chatbot = gr.Chatbot(elem_id="chatbot", label="Assistant Output (user/system hidden)", height=500)
    history_state = gr.State(value=[])
    system_input = gr.Textbox(value="You are a Vibe Coder assistant.", label="System message")
    user_input = gr.Textbox(placeholder="Type a user message and press Send", label="Your message")
    max_tokens = gr.Slider(minimum=1, maximum=4000, value=800, step=1, label="Max new tokens")
    temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.01, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
    send_btn = gr.Button("Send")
    send_btn.click(
        fn=respond_stream,
        inputs=[user_input, history_state, system_input, max_tokens, temperature, top_p],
        outputs=[chatbot, history_state],  # stream rows and persist history
        queue=True,
    )
    # Clear the input box after sending. (The original also registered a
    # no-op handler that returned the state unchanged; it has been dropped
    # since respond_stream now updates history_state itself.)
    send_btn.click(lambda: "", None, user_input)
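    # Optional convenience (not in the original): pressing Enter in the
    # textbox triggers the same handler as the Send button.
    user_input.submit(
        fn=respond_stream,
        inputs=[user_input, history_state, system_input, max_tokens, temperature, top_p],
        outputs=[chatbot, history_state],
    )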
if __name__ == "__main__":
    demo.launch()