legolasyiu committed
Commit 30a476d · verified · 1 Parent(s): d704def

Update app.py

Files changed (1):
  1. app.py +56 -54
app.py CHANGED
@@ -10,33 +10,20 @@ from transformers import (
 
 MODEL_ID = "EpistemeAI/gpt-oss-20b-RL"
 
-# --------- Model load (do this once at startup) ----------
-# Adjust dtype / device_map to your environment.
-# If you have limited GPU memory, consider: device_map="auto", load_in_8bit=True (requires bitsandbytes)
 print("Loading tokenizer and model (this may take a while)...")
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
-# Recommended: try device_map="auto" with accelerate installed; fallback to cpu if not available.
-try:
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID,
-        torch_dtype="auto",
-        device_map="cuda",
-    )
-except Exception as e:
-    print("Automatic device_map load failed, falling back to cpu. Error:", e)
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID,
-        torch_dtype="auto",
-        device_map="auto",
-    )
-
+# Always use auto mapping / dtype
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID,
+    torch_dtype="auto",
+    device_map="auto",
+)
 model.eval()
 print("Model loaded. Device:", next(model.parameters()).device)
 
 # --------- Helper: build prompt ----------
 def build_prompt(system_message: str, history: list[dict], user_message: str) -> str:
-    # Keep your conversation structure — adapt to model's preferred format if needed.
     pieces = []
     if system_message:
         pieces.append(f"<|system|>\n{system_message}\n")
@@ -47,27 +34,11 @@ def build_prompt(system_message: str, history: list[dict], user_message: str) ->
     pieces.append(f"<|user|>\n{user_message}\n<|assistant|>\n")
     return "\n".join(pieces)
 
-# --------- Gradio respond function (streams tokens) ----------
-def respond(
-    message,
-    history: list[dict],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-    hf_token=None,  # kept for compatibility with UI; not used for local pipeline
-):
-    """
-    Streams tokens as they are generated using TextIteratorStreamer.
-    Gradio will accept a generator yielding partial response strings.
-    """
-    prompt = build_prompt(system_message, history or [], message)
-
-    # Prepare inputs
+# --------- Streaming generator ----------
+def generate_stream(prompt, max_tokens, temperature, top_p):
     inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = inputs["input_ids"].to(model.device)
 
-    # Create streamer to yield token-chunks as they are generated
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
     gen_kwargs = dict(
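Note: build_prompt hand-rolls <|system|>/<|user|> markers. If the checkpoint ships a chat template (an assumption; many instruction-tuned checkpoints do), the tokenizer can build the prompt instead. A sketch:

    # Hypothetical alternative to build_prompt, relying on the tokenizer's
    # chat_template (assumes the checkpoint defines one).
    messages = [
        {"role": "system", "content": "You are a Vibe Coder assistant."},
        {"role": "user", "content": "Write a two-line haiku function."},
    ]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,              # return a prompt string rather than token IDs
        add_generation_prompt=True,  # append the assistant header for generation
    )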
@@ -79,33 +50,64 @@ def respond(
         streamer=streamer,
     )
 
-    # Start generation in background thread
     thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
     thread.start()
 
     partial = ""
-    # Iterate streamer yields token chunks (strings)
     for token_str in streamer:
         partial += token_str
         yield partial
 
-# --------- Build Gradio UI ----------
-chatbot = gr.ChatInterface(
-    respond,
-    type="messages",
-    additional_inputs=[
-        gr.Textbox(value="You are a Vibe Coder assistant.", label="System message"),
-        gr.Slider(minimum=1, maximum=4000, value=2000, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.01, label="Temperature"),
-        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"),
-    ],
-)
+# --------- Gradio app logic ----------
+def respond_stream(user_message, chat_history, system_message, max_tokens, temperature, top_p):
+    history = chat_history or []
+    prompt = build_prompt(system_message or "", history, user_message or "")
+
+    history.append({"role": "user", "content": user_message})
+    history.append({"role": "assistant", "content": ""})
+
+    def history_to_chatbot_rows(hist):
+        rows = []
+        for item in hist:
+            if item["role"] == "assistant":
+                rows.append(("thinking...", item["content"] or "thinking..."))
+        return rows or []
+
+    chatbot_rows = history_to_chatbot_rows(history[:-1])
+    chatbot_rows.append(("thinking...", "thinking..."))
+    yield chatbot_rows  # placeholder row
 
+    for partial in generate_stream(prompt, max_tokens, temperature, top_p):
+        chatbot_rows[-1] = ("thinking...", partial)
+        history[-1]["content"] = partial
+        yield chatbot_rows
+
+    chatbot_rows[-1] = ("thinking...", history[-1]["content"])
+    yield chatbot_rows
+
+# --------- Build Gradio UI ----------
 with gr.Blocks() as demo:
-    with gr.Sidebar():
-        gr.Markdown("Model: " + MODEL_ID)
-        gr.LoginButton()
-    chatbot.render()
+    gr.Markdown(f"**Model:** {MODEL_ID}")
+    with gr.Row():
+        chatbot = gr.Chatbot(elem_id="chatbot", label="Assistant Output (user/system hidden)").style(height=500)
+
+    history_state = gr.State(value=[])
+    system_input = gr.Textbox(value="You are a Vibe Coder assistant.", label="System message")
+    user_input = gr.Textbox(placeholder="Type a user message and press Send", label="Your message")
+    max_tokens = gr.Slider(minimum=1, maximum=4000, value=800, step=1, label="Max new tokens")
+    temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.01, label="Temperature")
+    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
+    send_btn = gr.Button("Send")
+
+    send_btn.click(
+        fn=respond_stream,
+        inputs=[user_input, history_state, system_input, max_tokens, temperature, top_p],
+        outputs=[chatbot],
+        queue=True,
+    )
+
+    send_btn.click(lambda u, s: s, inputs=[user_input, history_state], outputs=[history_state])
+    send_btn.click(lambda: "", None, user_input)
 
 if __name__ == "__main__":
     demo.launch()
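Note: both versions share the TextIteratorStreamer pattern: generate() blocks until done, so it runs in a worker thread while the main thread iterates decoded chunks from the streamer. Reduced to a self-contained sketch (the tiny checkpoint and sampling values are only for illustration):

    import threading
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # tiny model, fast to test
    lm = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

    inputs = tok("Hello", return_tensors="pt")
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

    # Run generation in a worker thread; the main thread consumes decoded
    # text chunks from the streamer's internal queue as they arrive.
    worker = threading.Thread(
        target=lm.generate,
        kwargs=dict(**inputs, max_new_tokens=20, do_sample=True, streamer=streamer),
    )
    worker.start()
    for chunk in streamer:
        print(chunk, end="", flush=True)
    worker.join()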
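One caveat on the new UI block: tuple-format Chatbot rows and .style(height=500) are Gradio 3.x idioms (.style() was removed in 4.x), and the three parallel .click() handlers on send_btn can race, clearing user_input while respond_stream is still reading it. A sketch of equivalent wiring under a Gradio 4.x assumption, with a stand-in echo handler and a chained clear:

    import gradio as gr

    def echo_stream(user_message, rows):
        # Stand-in for respond_stream: stream a reply character by character.
        rows = (rows or []) + [(user_message, "")]
        for ch in f"you said: {user_message}":
            rows[-1] = (rows[-1][0], rows[-1][1] + ch)
            yield rows

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(label="Assistant Output", height=500)  # height is a constructor arg in 4.x
        user_input = gr.Textbox(label="Your message")
        send_btn = gr.Button("Send")

        # One chained event instead of three parallel handlers: the textbox is
        # cleared only after streaming finishes.
        send_btn.click(
            echo_stream, inputs=[user_input, chatbot], outputs=[chatbot]
        ).then(lambda: "", None, user_input)

    if __name__ == "__main__":
        demo.launch()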