# Source: Hugging Face file view of app.py — arubaDev, commit 33cfdca
# ("Create app.py", verified), 11.3 kB.
import os
import sqlite3
import time
from datetime import datetime
import gradio as gr
from huggingface_hub import InferenceClient
# ---------------------------
# Config
# ---------------------------
# UI display name -> Hugging Face model repository id. Insertion order
# matters: the first entry is used as the default/fallback model.
MODELS = {
    "Meta LLaMA 3.1 (8B Instruct)": "meta-llama/Llama-3.1-8B-Instruct",
    "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
}

# API token read from the environment (set this in your Space's Secrets);
# None when the variable is not defined.
HF_TOKEN = os.environ.get("HF_TOKEN")

# SQLite file that stores sessions and messages.
DB_PATH = "history.db"

# Initial content of the "System message" textbox.
SYSTEM_DEFAULT = " ".join(
    (
        "You are a coding assistant.",
        "Respond only with clean and complete code unless explanation is explicitly requested.",
        "Prefer full CRUD scaffolds, with files, paths, and commands when asked.",
    )
)
# ---------------------------
# DB Setup
# ---------------------------
def db():
    """Open a new SQLite connection to ``DB_PATH``.

    Every call returns a fresh connection; the caller is responsible for
    committing and closing it.

    Returns:
        sqlite3.Connection: a connection with WAL journaling requested and
        foreign-key enforcement switched on.
    """
    conn = sqlite3.connect(DB_PATH)
    conn.execute("PRAGMA journal_mode=WAL;")
    # SQLite ignores FOREIGN KEY clauses unless enforcement is enabled on each
    # connection; without this pragma the ON DELETE CASCADE declared on
    # messages.session_id is never applied.
    conn.execute("PRAGMA foreign_keys=ON;")
    return conn
def init_db():
    """Create the ``sessions`` and ``messages`` tables if they do not exist."""
    # One row per chat: id, title and creation timestamp.
    sessions_ddl = (
        "\n"
        "CREATE TABLE IF NOT EXISTS sessions (\n"
        "id INTEGER PRIMARY KEY AUTOINCREMENT,\n"
        "title TEXT NOT NULL,\n"
        "created_at TEXT NOT NULL\n"
        ")\n"
    )
    # One row per chat turn, linked to its session by session_id.
    messages_ddl = (
        "\n"
        "CREATE TABLE IF NOT EXISTS messages (\n"
        "id INTEGER PRIMARY KEY AUTOINCREMENT,\n"
        "session_id INTEGER NOT NULL,\n"
        "role TEXT NOT NULL,\n"
        "content TEXT NOT NULL,\n"
        "created_at TEXT NOT NULL,\n"
        "FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE\n"
        ")\n"
    )
    conn = db()
    cur = conn.cursor()
    for ddl in (sessions_ddl, messages_ddl):
        cur.execute(ddl)
    conn.commit()
    conn.close()
def create_session(title: str = "New chat") -> int:
    """Insert a new chat session and return its id.

    Args:
        title: Title stored for the session.

    Returns:
        The id of the newly inserted ``sessions`` row.
    """
    conn = db()
    try:
        cur = conn.execute(
            "INSERT INTO sessions (title, created_at) VALUES (?, ?)",
            (title, datetime.utcnow().isoformat()),
        )
        conn.commit()
        return cur.lastrowid
    finally:
        # Close even when the INSERT or commit raises (e.g. a locked
        # database) so a failed call does not leave a connection open.
        conn.close()
def delete_session(session_id: int):
    """Delete a session together with all of its messages.

    Args:
        session_id: Id of the session to remove. An unknown id is a no-op.
    """
    conn = db()
    try:
        # Messages are removed explicitly, so the delete is complete whether
        # or not foreign-key cascading is active on this connection.
        conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
        conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
        conn.commit()
    finally:
        # Closing without a commit discards a half-done delete, and the
        # connection is released even when a statement raises.
        conn.close()
def list_sessions():
    """Return every session, newest first, as display labels plus raw rows.

    Returns:
        A ``(labels, rows)`` tuple. ``rows`` is the list of ``(id, title)``
        tuples ordered by descending id, and ``labels[i]`` is the
        ``"<id> • <title>"`` string built from ``rows[i]``.
    """
    conn = db()
    try:
        rows = conn.execute(
            "SELECT id, title FROM sessions ORDER BY id DESC"
        ).fetchall()
    finally:
        # Release the connection even if the query raises.
        conn.close()
    # The id must be followed by " • ": label_to_id() splits on "•" to recover
    # the id, and the callbacks look labels up by their "<id> " prefix. Gluing
    # id and title together made every label unparseable.
    labels = [f"{sid} • {title}" for sid, title in rows]
    return labels, rows
def get_messages(session_id: int):
    """Load the messages of one session in insertion order.

    Args:
        session_id: Id of the session whose messages are wanted.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts ordered by
        ascending message id; empty when the session has no messages.
    """
    query = (
        "\n"
        "SELECT role, content FROM messages\n"
        "WHERE session_id = ?\n"
        "ORDER BY id ASC\n"
    )
    conn = db()
    cur = conn.cursor()
    cur.execute(query, (session_id,))
    fetched = cur.fetchall()
    conn.close()
    history = []
    for role, content in fetched:
        history.append({"role": role, "content": content})
    return history
def add_message(session_id: int, role: str, content: str):
    """Append one message to a session.

    Args:
        session_id: Id of the session the message belongs to.
        role: Message role as stored in the ``role`` column (the callers in
            this file use ``"user"`` and ``"assistant"``).
        content: Message text.
    """
    conn = db()
    try:
        conn.execute(
            "INSERT INTO messages (session_id, role, content, created_at) VALUES (?, ?, ?, ?)",
            (session_id, role, content, datetime.utcnow().isoformat()),
        )
        conn.commit()
    finally:
        # Close even when the INSERT or commit raises so a failed write does
        # not leave a connection open.
        conn.close()
def update_session_title_if_needed(session_id: int, first_user_text: str):
    """Title a session after its first user message.

    The title is only written while the session holds exactly one user
    message, so later messages never rename the chat. It is the first line of
    the text, cut to 50 characters plus an ellipsis when longer, and falls
    back to ``"New chat"`` when the text is blank.

    Args:
        session_id: Id of the session to retitle.
        first_user_text: Text of the user message that was just stored.
    """
    conn = db()
    try:
        (count_users,) = conn.execute(
            "SELECT COUNT(*) FROM messages WHERE session_id=? AND role='user'",
            (session_id,),
        ).fetchone()
        if count_users == 1:
            stripped = first_user_text.strip()
            # splitlines() also understands "\r\n", so a message typed with
            # Windows line endings does not leave a stray "\r" in the title.
            title = stripped.splitlines()[0] if stripped else ""
            title = (title[:50] + "…") if len(title) > 50 else title
            conn.execute(
                "UPDATE sessions SET title=? WHERE id=?",
                (title or "New chat", session_id),
            )
            conn.commit()
    finally:
        # Release the connection even when a statement raises.
        conn.close()
# ---------------------------
# Helpers
# ---------------------------
def label_to_id(label: str | None) -> int | None:
if not label:
return None
try:
return int(label.split("•", 1)[0].strip())
except Exception:
return None
def build_api_messages(session_id: int, system_message: str):
    """Assemble the message list sent to the model for one session.

    Args:
        session_id: Id of the session whose stored history is included.
        system_message: System prompt; surrounding whitespace is stripped.

    Returns:
        A list starting with the system message followed by the session's
        stored messages in order.
    """
    system_entry = {"role": "system", "content": system_message.strip()}
    return [system_entry, *get_messages(session_id)]
def get_client(model_choice: str):
    """Build an ``InferenceClient`` for the model picked in the UI.

    Args:
        model_choice: A key of ``MODELS``. Unknown names fall back to the
            first model listed in ``MODELS``.

    Returns:
        An ``InferenceClient`` created with the resolved model id and
        ``HF_TOKEN``.
    """
    try:
        model_id = MODELS[model_choice]
    except KeyError:
        model_id = list(MODELS.values())[0]
    return InferenceClient(model_id, token=HF_TOKEN)
# ---------------------------
# Gradio Callbacks
# ---------------------------
def refresh_sessions_cb():
    """Reload the session radio from the database.

    Returns:
        A ``gr.update`` carrying the current labels as choices and the first
        label as the value (``None`` when there are no sessions).
    """
    labels = list_sessions()[0]
    first_label = None if not labels else labels[0]
    return gr.update(choices=labels, value=first_label)
def new_chat_cb():
    """Create an empty session and make it the selected one.

    Returns:
        A 3-tuple for the session radio, the chatbot and the input box: a
        ``gr.update`` with the refreshed choices and the new session's label
        (``None`` if no label matches), an empty message list, and an empty
        string.
    """
    sid = create_session("New chat")
    labels = list_sessions()[0]
    # Labels are looked up by their "<id> " prefix.
    prefix = f"{sid} "
    selected = None
    for lbl in labels:
        if lbl.startswith(prefix):
            selected = lbl
            break
    return gr.update(choices=labels, value=selected), [], ""
def load_session_cb(selected_label):
    """Return the stored messages of the selected session.

    Args:
        selected_label: Currently selected radio label, or ``None``.

    Returns:
        The session's message list, or an empty list when the label yields
        no usable session id.
    """
    sid = label_to_id(selected_label)
    return get_messages(sid) if sid else []
def delete_chat_cb(selected_label):
    """Delete the selected session and refresh the session radio.

    Args:
        selected_label: Currently selected radio label, or ``None``.

    Returns:
        A 2-tuple for the session radio and the chatbot: a ``gr.update``
        with the remaining labels and the first of them selected (``None``
        when none remain), and an empty message list.
    """
    sid = label_to_id(selected_label)
    if sid:
        delete_session(sid)
    remaining = list_sessions()[0]
    if remaining:
        radio_update = gr.update(choices=remaining, value=remaining[0])
    else:
        radio_update = gr.update(choices=remaining, value=None)
    return radio_update, []
def send_cb(user_text, selected_label, chatbot_msgs, system_message, max_tokens, temperature, top_p, model_choice):
    """Store the user's message, stream the model's reply, and persist it.

    Generator callback feeding three outputs: the chatbot, the input box and
    the session radio.

    Args:
        user_text: Text typed by the user. Blank input is ignored.
        selected_label: Currently selected session label, or ``None``; a new
            session is created when no id can be read from it.
        chatbot_msgs: Messages currently shown in the chatbot; only echoed
            back when there is nothing to send.
        system_message: System prompt sent ahead of the stored history.
        max_tokens: Generation limit, converted with ``int``.
        temperature: Sampling temperature, converted with ``float``.
        top_p: Nucleus-sampling value, converted with ``float``.
        model_choice: Key of ``MODELS`` selecting the model.

    Yields:
        ``(messages, "", session_value)`` tuples. ``messages`` is the stored
        history plus the assistant reply accumulated so far (or an error
        text), ``""`` clears the input box, and ``session_value`` is the
        selected label, or a ``gr.update`` that also carries refreshed
        choices when a session was created here.
    """
    # Nothing to send: do not store an empty message or call the model.
    if not user_text or not user_text.strip():
        yield (chatbot_msgs, "", selected_label)
        return

    session_value = selected_label
    sid = label_to_id(selected_label)
    if sid is None:
        sid = create_session("New chat")
        labels, rows = list_sessions()
        # Match on the row id (labels[i] is built from rows[i]) instead of on
        # the label text.
        selected_label = next(
            (lbl for lbl, row in zip(labels, rows) if row[0] == sid), None
        )
        # The radio has never seen this label; send the refreshed choices
        # along with the value so the selection is one of its choices.
        session_value = gr.update(choices=labels, value=selected_label)

    add_message(sid, "user", user_text)
    update_session_title_if_needed(sid, user_text)
    api_messages = build_api_messages(sid, system_message)

    # Shown history = stored messages plus a placeholder filled while streaming.
    display_msgs = get_messages(sid)
    display_msgs.append({"role": "assistant", "content": ""})

    client = get_client(model_choice)
    partial = ""
    try:
        for chunk in client.chat_completion(
            messages=api_messages,
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            stream=True,
        ):
            # Defensive: a chunk without choices would make [0] raise and
            # replace the reply received so far with an error.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content or ""
            if delta:
                partial += delta
                display_msgs[-1]["content"] = partial
                yield (display_msgs, "", session_value)
        add_message(sid, "assistant", partial)
        # Final update so the UI is refreshed (and the input box cleared)
        # even when the stream produced no text at all.
        yield (display_msgs, "", session_value)
    except Exception as e:
        # UI boundary: show the failure in the chat instead of crashing the
        # callback. The failed reply is not stored.
        err = f"⚠️ Error: {str(e)}"
        display_msgs[-1]["content"] = err
        yield (display_msgs, "", session_value)
def regenerate_cb(selected_label, system_message, max_tokens, temperature, top_p, model_choice):
    """Replace the last assistant reply of a session with a fresh one.

    Generator callback feeding the chatbot. If the session's last stored
    message is an assistant message it is deleted, then the model is asked
    again with the remaining history and the new reply is streamed and stored.

    Args:
        selected_label: Currently selected session label, or ``None``.
        system_message: System prompt sent ahead of the stored history.
        max_tokens: Generation limit, converted with ``int``.
        temperature: Sampling temperature, converted with ``float``.
        top_p: Nucleus-sampling value, converted with ``float``.
        model_choice: Key of ``MODELS`` selecting the model.

    Yields:
        The message list to display: an empty list when there is no session
        or no history, otherwise the history plus the reply accumulated so
        far (or an error text).
    """
    sid = label_to_id(selected_label)
    if sid is None:
        # This function is a generator, so a value passed to ``return`` never
        # reaches the UI; the empty chat has to be yielded.
        yield []
        return

    msgs = get_messages(sid)
    if not msgs:
        yield []
        return

    if msgs[-1]["role"] == "assistant":
        conn = db()
        try:
            # Drop only the newest message of this session (the old reply).
            conn.execute(
                "DELETE FROM messages WHERE id = ("
                "SELECT id FROM messages WHERE session_id=? ORDER BY id DESC LIMIT 1"
                ")",
                (sid,),
            )
            conn.commit()
        finally:
            # Release the connection even when the delete raises.
            conn.close()
        msgs = get_messages(sid)

    api_messages = [{"role": "system", "content": system_message.strip()}] + msgs
    # Shown history = stored messages plus a placeholder filled while streaming.
    display_msgs = msgs + [{"role": "assistant", "content": ""}]

    client = get_client(model_choice)
    partial = ""
    try:
        for chunk in client.chat_completion(
            messages=api_messages,
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            stream=True,
        ):
            # Defensive: a chunk without choices would make [0] raise and
            # replace the reply received so far with an error.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content or ""
            if delta:
                partial += delta
                display_msgs[-1]["content"] = partial
                yield display_msgs
        add_message(sid, "assistant", partial)
        # Final update so the chatbot is refreshed even when the stream
        # produced no text at all.
        yield display_msgs
    except Exception as e:
        # UI boundary: show the failure in the chat instead of crashing the
        # callback. The failed reply is not stored.
        display_msgs[-1]["content"] = f"⚠️ Error: {str(e)}"
        yield display_msgs
# ---------------------------
# App UI
# ---------------------------
# Make sure the tables exist before anything queries them.
init_db()
labels, _ = list_sessions()
if len(labels) == 0:
    # Seed one session so the radio list is not empty on first launch.
    first_sid = create_session("New chat")
    labels, _ = list_sessions()
# Preselect the first label (list_sessions() returns the newest session first).
default_selected = None if not labels else labels[0]
# Build the Gradio interface: session list, model picker and generation
# settings, the chat area with its input box, and the event wiring.
# NOTE(review): which widgets sit inside which Row/Column is reconstructed
# from statement order (in particular whether refresh_btn shares the row of
# the New/Delete buttons) — confirm against the intended layout.
with gr.Blocks(title="LLaMA/Mistral CRUD Automation (with History)", theme=gr.themes.Soft()) as demo:
    # --- Updated CSS to make ALL buttons green ---
    # The selector is the bare `button` element with !important overrides.
    gr.HTML("""
<style>
button {
background-color: #22c55e !important;
color: #ffffff !important;
border: none !important;
}
button:hover {
background-color: #16a34a !important;
}
button:focus {
outline: 2px solid #166534 !important;
outline-offset: 2px;
}
</style>
""")
    gr.Markdown("## 🦙🤖 LLaMA & Mistral CRUD Automation — with Persistent History")
    with gr.Row():
        with gr.Column(scale=1, min_width=260):
            gr.Markdown("### 📁 Sessions")
            # Radio listing the stored sessions; its value is the label that
            # the callbacks turn back into a session id with label_to_id().
            session_list = gr.Radio(
                choices=labels,
                value=default_selected,
                label="Your chats",
                interactive=True
            )
            with gr.Row():
                new_btn = gr.Button("➕ New Chat", variant="primary")
                del_btn = gr.Button("🗑️ Delete", variant="stop")
            refresh_btn = gr.Button("🔄 Refresh", variant="secondary")
            gr.Markdown("### 🤖 Model Selection")
            # Choices are the display names (keys) of MODELS; the first key is
            # preselected.
            model_choice = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Choose a model",
                interactive=True
            )
            gr.Markdown("### ⚙️ Generation Settings")
            system_box = gr.Textbox(
                value=SYSTEM_DEFAULT,
                label="System message",
                lines=4
            )
            # These three values are passed to send_cb / regenerate_cb, which
            # forward them to chat_completion.
            max_tokens = gr.Slider(256, 4096, value=1200, step=16, label="Max tokens")
            temperature = gr.Slider(0.0, 2.0, value=0.25, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
        with gr.Column(scale=3):
            # type="messages": the chatbot takes the {"role", "content"} dicts
            # produced by get_messages().
            chatbot = gr.Chatbot(label="Assistant", height=520, type="messages")
            with gr.Row():
                user_box = gr.Textbox(placeholder="Describe your CRUD task…", lines=3, scale=5)
            with gr.Row():
                send_btn = gr.Button("Send ▶️", variant="primary")
                regen_btn = gr.Button("Regenerate 🔁", variant="secondary")
    # Interactions
    # Refresh: reload the session radio from the database.
    refresh_btn.click(refresh_sessions_cb, outputs=session_list)
    # New chat: create a session, select it, clear the chatbot and input box.
    new_btn.click(new_chat_cb, outputs=[session_list, chatbot, user_box])
    # Delete: remove the selected session, refresh the radio, clear the chatbot.
    del_btn.click(delete_chat_cb, inputs=session_list, outputs=[session_list, chatbot])
    # Selecting another session loads its stored messages into the chatbot.
    session_list.change(load_session_cb, inputs=session_list, outputs=chatbot)
    # The Send button and submitting the textbox run the same callback with
    # the same inputs and outputs.
    send_btn.click(
        send_cb,
        inputs=[user_box, session_list, chatbot, system_box, max_tokens, temperature, top_p, model_choice],
        outputs=[chatbot, user_box, session_list]
    )
    user_box.submit(
        send_cb,
        inputs=[user_box, session_list, chatbot, system_box, max_tokens, temperature, top_p, model_choice],
        outputs=[chatbot, user_box, session_list]
    )
    # Regenerate: re-ask the model for the selected session's last reply.
    regen_btn.click(
        regenerate_cb,
        inputs=[session_list, system_box, max_tokens, temperature, top_p, model_choice],
        outputs=chatbot
    )
# Start the app only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()