# save as app.py
import threading
import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)

MODEL_ID = "EpistemeAI/gpt-oss-20b-RL"

print("Loading tokenizer and model (this may take a while)...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Always use auto mapping / dtype
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)
model.eval()
print("Model loaded. Device:", next(model.parameters()).device)

# --------- Helper: build prompt ----------
def build_prompt(system_message: str, history: list[dict], user_message: str) -> str:
    pieces = []
    if system_message:
        pieces.append(f"<|system|>\n{system_message}\n")
    for turn in history:
        role = turn.get("role", "user")
        content = turn.get("content", "")
        pieces.append(f"<|{role}|>\n{content}\n")
    pieces.append(f"<|user|>\n{user_message}\n<|assistant|>\n")
    return "\n".join(pieces)
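# Example of the prompt format produced above (hypothetical values, shown for
# illustration only):
#   build_prompt("Be brief.", [], "Hi") ->
#   "<|system|>\nBe brief.\n\n<|user|>\nHi\n<|assistant|>\n"
# If the tokenizer ships a chat template, tokenizer.apply_chat_template could be
# used instead; this manual <|role|> format is an assumption about what the
# model expects.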

# --------- Streaming generator ----------
def generate_stream(prompt, max_tokens, temperature, top_p):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)

    # Generation runs in a background thread; decoded tokens are pushed into the
    # streamer so the UI can display partial output as it is produced.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        streamer=streamer,
    )

    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    # Yield the cumulative text generated so far on every new token.
    partial = ""
    for token_str in streamer:
        partial += token_str
        yield partial
    thread.join()
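# Standalone usage sketch (hypothetical, outside the Gradio UI); each yield is
# the cumulative text so far, so the last value is the full reply:
#   reply = ""
#   for reply in generate_stream(build_prompt("", [], "Hello"), 64, 0.7, 0.9):
#       pass
#   print(reply)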

# --------- Gradio app logic ----------
def respond_stream(user_message, chat_history, system_message, max_tokens, temperature, top_p):
    # Reuse the gr.State list so in-place updates persist across turns
    # ("chat_history or []" would silently replace the shared state list when empty).
    history = chat_history if chat_history is not None else []
    prompt = build_prompt(system_message or "", history, user_message or "")

    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": ""})

    # The Chatbot shows (user, assistant) pairs; the user/system side is hidden
    # behind a "thinking..." placeholder, as noted in the component label.
    def history_to_chatbot_rows(hist):
        rows = []
        for item in hist:
            if item["role"] == "assistant":
                rows.append(("thinking...", item["content"] or "thinking..."))
        return rows

    chatbot_rows = history_to_chatbot_rows(history[:-1])
    chatbot_rows.append(("thinking...", "thinking..."))
    yield chatbot_rows  # placeholder row while generation starts

    # Stream the growing assistant reply into the last row.
    for partial in generate_stream(prompt, max_tokens, temperature, top_p):
        chatbot_rows[-1] = ("thinking...", partial)
        history[-1]["content"] = partial
        yield chatbot_rows

    chatbot_rows[-1] = ("thinking...", history[-1]["content"])
    yield chatbot_rows

# --------- Build Gradio UI ----------
with gr.Blocks() as demo:
    gr.Markdown(f"**Model:** {MODEL_ID}")
    with gr.Row():
        chatbot = gr.Chatbot(elem_id="chatbot", label="Assistant Output (user/system hidden)", height=500)

    history_state = gr.State(value=[])
    system_input = gr.Textbox(value="You are a Vibe Coder assistant.", label="System message")
    user_input = gr.Textbox(placeholder="Type a user message and press Send", label="Your message")
    max_tokens = gr.Slider(minimum=1, maximum=4000, value=800, step=1, label="Max new tokens")
    temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.01, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
    send_btn = gr.Button("Send")

    send_btn.click(
        fn=respond_stream,
        inputs=[user_input, history_state, system_input, max_tokens, temperature, top_p],
        outputs=[chatbot],
        queue=True,
    )

    # Pass the same history list object back to the State (respond_stream
    # mutates it in place), then clear the input box.
    send_btn.click(lambda u, s: s, inputs=[user_input, history_state], outputs=[history_state])
    send_btn.click(lambda: "", None, user_input)

if __name__ == "__main__":
    demo.queue()  # the queue must be enabled for streaming (generator) handlers
    demo.launch()
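
# To run (a sketch; assumes these packages are installed):
#   pip install gradio transformers accelerate torch
#   python app.py
# Gradio then prints a local URL (http://127.0.0.1:7860 by default).
# Note: gpt-oss-20b is a ~20B-parameter model, so loading it needs substantial
# GPU memory, or slow CPU offload via device_map="auto".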