from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import torch

# --- Model / tokenizer load (your checkpoint) ---
checkpoint = "EpistemeAI/metatune-gpt20b-R0"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto").to(device)
model.eval()
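
# The response parsing further down assumes the chat template emits ChatML-style
# markers like the sketch below (an assumption; check this checkpoint's actual
# template, e.g. by printing tokenizer.apply_chat_template on a sample history):
#
#   <|im_start|>user
#   Hello<|im_end|>
#   <|im_start|>assistant
#   Hi there!<|im_end|>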

# --- Helper: convert gradio display history (pairs) -> model/chat history (dicts) ---
def display_to_model_history(display_history):
    """
    Convert gradio Chatbot history (a list of (user_text, assistant_text) pairs)
    into the list of {"role": ..., "content": ...} dicts expected by
    tokenizer.apply_chat_template.
    """
    model_history = []
    for user_text, assistant_text in display_history or []:
        if user_text:
            model_history.append({"role": "user", "content": user_text})
        if assistant_text:
            model_history.append({"role": "assistant", "content": assistant_text})
    return model_history
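
# Example (hypothetical values) of the conversion this helper performs:
#   display_to_model_history([("Hi", "Hello! How can I help?")])
#   -> [{"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello! How can I help?"}]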

# --- Prediction (generator) that shows "Thinking..." and then the final output ---
def predict(user_message, chat_history):
    """
    Args:
        user_message: string typed by the user
        chat_history: list of (user_text, assistant_text) pairs from the gradio Chatbot

    Yields:
        The updated chat_history (so gradio refreshes the UI). The first yield shows
        a "Thinking..." placeholder; the second yields the final assistant response.
    """
    # Ensure history is a list
    chat_history = chat_history or []

    # 1) Build the model-side history from previous turns, then add the new user turn
    model_history = display_to_model_history(chat_history)
    model_history.append({"role": "user", "content": user_message})

    # 2) Append the new turn with a "Thinking..." placeholder and yield so the user sees it
    chat_history.append((user_message, "Thinking..."))
    yield chat_history

    # 3) Build the prompt with the tokenizer's chat template; add_generation_prompt=True
    #    appends the assistant header so the model continues from there
    input_text = tokenizer.apply_chat_template(
        model_history, tokenize=False, add_generation_prompt=True
    )

    # 4) Tokenize and run generation (tune the sampling args as you prefer)
    inputs = tokenizer.encode(input_text, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=512,
            temperature=0.9,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)

    # 5) Extract the assistant response; adjust the delimiters if this checkpoint's
    #    chat template uses different markers
    try:
        response = decoded.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0].strip()
    except Exception:
        # Fallback: use the full decoded text
        response = decoded.strip()

    # 6) Replace the "Thinking..." placeholder (the pair we just appended) with the response
    chat_history[-1] = (user_message, response)

    # 7) Final yield with the assistant output
    yield chat_history
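
# Minimal sketch of driving predict() outside Gradio (hypothetical usage; the generator
# yields the history twice: first with the "Thinking..." placeholder, then the answer):
#
#   history = []
#   for history in predict("What does this model do?", history):
#       pass
#   print(history[-1][1])  # final assistant reply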

# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("## Episteme Chat — shows 'Thinking...' then final assistant output")
    chatbot = gr.Chatbot(height=600)
    txt = gr.Textbox(show_label=False, placeholder="Type your message and hit Enter")
    clear = gr.Button("Clear")
    
    # Bind the generator to textbox submit
    txt.submit(predict, inputs=[txt, chatbot], outputs=chatbot)
    clear.click(lambda: None, None, chatbot, queue=False)  # clears chat (returns None)
    
demo.launch()
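
# Note: generator handlers stream their intermediate yields only when gradio's queue
# is active. Recent gradio versions enable queuing by default; on older versions you
# may need to launch with demo.queue().launch() instead.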