import os
import sqlite3
from datetime import datetime
import gradio as gr
from huggingface_hub import InferenceClient
from datasets import load_dataset
# ---------------------------
# Config
# ---------------------------
MODELS = {
    "Meta LLaMA 3.1 (8B Instruct)": "meta-llama/Llama-3.1-8B-Instruct",
    "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
}
DATASETS = ["The Stack", "CodeXGLUE"]
HF_TOKEN = os.getenv("HF_TOKEN")
DB_PATH = "history.db"
SYSTEM_DEFAULT = (
    "Specializes in databases, APIs, auth, CRUD. "
    "Provides complete backend code scaffolds. "
    "Declines frontend-heavy requests."
)
# ---------------------------
# DB Setup
# ---------------------------
def db(): return sqlite3.connect(DB_PATH)
def init_db():
    conn = db()
    cur = conn.cursor()
    cur.execute("""
        CREATE TABLE IF NOT EXISTS sessions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            title TEXT NOT NULL,
            created_at TEXT NOT NULL
        )
    """)
    cur.execute("""
        CREATE TABLE IF NOT EXISTS messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id INTEGER NOT NULL,
            role TEXT NOT NULL,
            content TEXT NOT NULL,
            created_at TEXT NOT NULL,
            FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE
        )
    """)
    conn.commit()
    conn.close()
def create_session(title: str = "New chat") -> int:
    conn = db()
    cur = conn.cursor()
    cur.execute(
        "INSERT INTO sessions (title, created_at) VALUES (?, ?)",
        (title, datetime.utcnow().isoformat())
    )
    sid = cur.lastrowid
    conn.commit()
    conn.close()
    return sid
def delete_session(session_id: int):
    conn = db()
    cur = conn.cursor()
    cur.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
    cur.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
    conn.commit()
    conn.close()
def list_sessions():
    conn = db()
    cur = conn.cursor()
    cur.execute("SELECT id, title FROM sessions ORDER BY id DESC")
    rows = cur.fetchall()
    conn.close()
    # Label format "<id> • <title>" so label_to_id() can split on "•"
    # and new_chat_cb() can match on the leading "<id> " prefix.
    labels = [f"{r[0]} • {r[1]}" for r in rows]
    return labels, rows
def get_messages(session_id: int):
    conn = db()
    cur = conn.cursor()
    cur.execute("""
        SELECT role, content
        FROM messages
        WHERE session_id = ?
        ORDER BY id ASC
    """, (session_id,))
    msgs = [{"role": role, "content": content} for (role, content) in cur.fetchall()]
    conn.close()
    return msgs
def add_message(session_id: int, role: str, content: str):
    conn = db()
    cur = conn.cursor()
    cur.execute(
        "INSERT INTO messages (session_id, role, content, created_at) VALUES (?, ?, ?, ?)",
        (session_id, role, content, datetime.utcnow().isoformat())
    )
    conn.commit()
    conn.close()
def update_session_title_if_needed(session_id: int, first_user_text: str):
    conn = db()
    cur = conn.cursor()
    cur.execute("SELECT COUNT(*) FROM messages WHERE session_id=? AND role='user'", (session_id,))
    if cur.fetchone()[0] == 1:
        title = first_user_text.strip().split("\n")[0]
        title = (title[:50] + "…") if len(title) > 50 else title
        cur.execute("UPDATE sessions SET title=? WHERE id=?", (title or "New chat", session_id))
        conn.commit()
    conn.close()
def label_to_id(label: str | None) -> int | None:
    if not label:
        return None
    try:
        return int(label.split("•", 1)[0].strip())
    except ValueError:
        return None
def build_api_messages(session_id: int, system_message: str):
    msgs = [{"role": "system", "content": system_message.strip()}]
    msgs.extend(get_messages(session_id))
    return msgs
def get_client(model_choice: str):
    return InferenceClient(MODELS.get(model_choice, list(MODELS.values())[0]), token=HF_TOKEN)
def load_dataset_by_name(name: str):
    if name == "The Stack":
        return load_dataset("bigcode/the-stack", split="train")
    if name == "CodeXGLUE":
        return load_dataset("google/code_x_glue_cc_code_to_code_trans", split="train")
    return None
FRONTEND_KEYWORDS = ["react", "vue", "angular", "html", "css", "javascript", "tailwind", "recharts", "typescript"]
def is_frontend_request(user_text: str) -> bool:
    return any(kw in user_text.lower() for kw in FRONTEND_KEYWORDS)
# ---------------------------
# Callbacks
# ---------------------------
def refresh_sessions_cb():
    labels, _ = list_sessions()
    selected = labels[0] if labels else None
    visible = bool(selected)
    return (
        gr.update(choices=labels, value=selected),
        gr.update(visible=visible),
        gr.update(visible=visible),
        gr.update(visible=visible),
        gr.update(visible=visible),
    )
def new_chat_cb():
    sid = create_session("New chat")
    labels, _ = list_sessions()
    selected = next((lbl for lbl in labels if lbl.startswith(f"{sid} ")), None)
    return (
        gr.update(choices=labels, value=selected),
        [], "",  # chatbot cleared, user box cleared
        gr.update(visible=True), gr.update(visible=True),
        gr.update(visible=True), gr.update(visible=True),
    )
def load_session_cb(selected_label):
    sid = label_to_id(selected_label)
    if not sid:
        return [], gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    return get_messages(sid), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
def delete_chat_cb(selected_label):
    sid = label_to_id(selected_label)
    if sid:
        delete_session(sid)
    labels, _ = list_sessions()
    selected = labels[0] if labels else None
    visible = bool(selected)
    return (
        gr.update(choices=labels, value=selected),
        [],
        gr.update(visible=visible), gr.update(visible=visible),
        gr.update(visible=visible), gr.update(visible=visible),
    )
# --- Send ---
def send_cb(user_text, selected_label, chatbot_msgs, system_message, max_tokens, temperature, top_p, model_choice, dataset_choice, *args):
    sid = label_to_id(selected_label) or create_session("New chat")
    add_message(sid, "user", user_text)
    update_session_title_if_needed(sid, user_text)
    display_msgs = chatbot_msgs[:] + [{"role": "user", "content": user_text}]
    # Backend-only guardrail: refuse obvious frontend requests.
    if is_frontend_request(user_text):
        apology = "⚠️ I'm a backend-focused assistant and cannot provide frontend code."
        display_msgs.append({"role": "assistant", "content": apology})
        add_message(sid, "assistant", apology)
        yield display_msgs, "", selected_label
        return
    # Placeholder bubble that the streamed tokens overwrite below.
    display_msgs.append({"role": "assistant", "content": "…"})
    yield display_msgs, "", selected_label
    client = get_client(model_choice)
    api_messages = build_api_messages(sid, system_message)
    partial = ""
    try:
        for chunk in client.chat_completion(
            messages=api_messages,
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            stream=True,
        ):
            if not hasattr(chunk, "choices") or not chunk.choices:
                continue
            choice = chunk.choices[0]
            delta = (
                getattr(getattr(choice, "delta", None), "content", None)
                or getattr(getattr(choice, "message", None), "content", None)
                or ""
            )
            if delta:
                partial += delta
                display_msgs[-1]["content"] = partial
                yield display_msgs, "", selected_label
        add_message(sid, "assistant", partial)
    except Exception as e:
        display_msgs[-1]["content"] = f"⚠️ Error: {e}"
        yield display_msgs, "", selected_label
# --- Regenerate ---
def regenerate_cb(selected_label, system_message, max_tokens, temperature, top_p, model_choice, dataset_choice):
    sid = label_to_id(selected_label)
    if not sid:
        yield []
        return
    msgs = get_messages(sid)
    # Drop the last assistant reply so it can be regenerated.
    if msgs and msgs[-1]["role"] == "assistant":
        conn = db()
        cur = conn.cursor()
        cur.execute("DELETE FROM messages WHERE id=(SELECT id FROM messages WHERE session_id=? ORDER BY id DESC LIMIT 1)", (sid,))
        conn.commit()
        conn.close()
        msgs = get_messages(sid)
    display_msgs = msgs + [{"role": "assistant", "content": ""}]
    client = get_client(model_choice)
    partial = ""
    try:
        for chunk in client.chat_completion(
            messages=[{"role": "system", "content": system_message.strip()}] + msgs,
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            stream=True,
        ):
            if not hasattr(chunk, "choices") or not chunk.choices:
                continue
            delta = (
                getattr(getattr(chunk.choices[0], "delta", None), "content", None)
                or getattr(getattr(chunk.choices[0], "message", None), "content", None)
                or ""
            )
            if delta:
                partial += delta
                display_msgs[-1]["content"] = partial
                yield display_msgs
        add_message(sid, "assistant", partial)
    except Exception as e:
        # This callback only updates the chatbot, so yield a single value here.
        display_msgs[-1]["content"] = f"⚠️ Error: {e}"
        yield display_msgs
# ---------------------------
# App UI
# ---------------------------
init_db()
labels, _ = list_sessions()
if not labels:
    create_session("New chat")
    labels, _ = list_sessions()
default_selected = labels[0] if labels else None
with gr.Blocks(title="Backend-Focused LLaMA/Mistral CRUD Assistant", theme=gr.themes.Soft()) as demo:
    # --- Custom CSS ---
    gr.HTML("""
    <style>
    /* ========== Compact section headings ========== */
    .compact-heading h3 {
        font-size: 0.9rem !important;
        font-weight: 600 !important;
        margin: 0.3rem 0 !important;
        padding: 0 !important;
    }
    /* ========== Main top title smaller ========== */
    .main-title h2 {
        font-size: 1.2rem !important;
        font-weight: 700 !important;
        margin-bottom: 0.6rem !important;
    }
    /* ========== Reduce spacing below sliders & dropdowns ========== */
    .compact-sliders .gr-slider,
    .compact-sliders .gr-number,
    .compact-sliders .gr-dropdown {
        margin-bottom: 0.4rem !important;
    }
    /* ========== Compact the left panel column ========== */
    .left-panel .gr-column {
        gap: 0.3rem !important;
    }
    /* ========== Tiny / pressed buttons ========== */
    .tiny-btn .gr-button {
        font-size: 0.75rem !important;  /* smaller text */
        padding: 2px 6px !important;    /* tight padding */
        min-height: 24px !important;    /* shorter button */
        line-height: 1.1 !important;    /* compact vertical spacing */
        border-radius: 4px !important;  /* slightly rounded */
    }
    /* ========== Remove extra top/left page gaps ========== */
    body, html {
        margin: 0 !important;
        padding: 0 !important;
    }
    .gradio-container {
        margin-top: -40px !important;   /* was -10px, move up more */
        margin-left: -40px !important;
        padding-top: 0 !important;
        padding-left: 0 !important;
    }
    .main-title,
    .gr-block:first-of-type {
        margin-top: -10px !important;   /* was -5px, pull up further */
        margin-left: -10px !important;
        padding-top: 0 !important;
        padding-left: 0 !important;
    }
    /* ========== Compact sliders: reduce vertical height and spacing ========== */
    .compact-sliders .gr-slider {
        height: 28px !important;          /* shorter track */
        padding: 0 !important;            /* remove inner padding */
        margin-bottom: 0.3rem !important; /* less spacing below each slider */
    }
    /* Compact slider label and value text */
    .compact-sliders .gr-slider label,
    .compact-sliders .gr-slider span {
        font-size: 0.8rem !important;     /* smaller text */
    }
    /* Tighten the numeric input box on sliders */
    .compact-sliders input[type="number"] {
        max-width: 70px !important;       /* smaller numeric input width */
        font-size: 0.8rem !important;
        padding: 2px 4px !important;
    }
    </style>
    """)
# gr.Markdown("## 🗄️ LLaMA & Mistral Backend-Focused CRUD Automation — Persistent History")
with gr.Row(equal_height=True):
with gr.Column(scale=1, min_width=260):
# gr.Markdown("### 📁 Sessions")
session_list = gr.Radio(choices=labels, value=default_selected, label="Your Chats", interactive=True)
with gr.Row(elem_id="session-actions-row"):
new_btn = gr.Button("New Chat")
rename_btn = gr.Button("Rename", visible=False)
save_title_btn = gr.Button("Save", visible=False)
del_btn = gr.Button("Delete", visible=False)
refresh_btn = gr.Button("Refresh", visible=False)
edit_title_box = gr.Textbox(label="Edit Chat Name", placeholder="Type new chat name…", visible=False)
# Model selection
# gr.Markdown("### 🤖 Model Selection" ,elem_classes="compact-heading")
model_choice = gr.Dropdown(choices=list(MODELS.keys()), value=list(MODELS.keys())[0], label="Choose a model", interactive=True)
# Dataset selection
# gr.Markdown("### 📚 Dataset Selection" ,elem_classes="compact-heading")
dataset_choice = gr.Dropdown(choices=DATASETS, value=DATASETS[0], label="Select a dataset", interactive=True)
# Generation settings
# gr.Markdown("### ⚙️ Generation Settings" ,elem_classes="compact-heading")
with gr.Group(elem_classes="system-message-compact"):
system_box = gr.Textbox(value=SYSTEM_DEFAULT, label="System message", lines=1)
with gr.Group(elem_classes="compact-sliders"):
max_tokens = gr.Slider(256, 4096, value=1200, step=16, label="Max tokens")
temperature = gr.Slider(0.0, 2.0, value=0.25, step=0.05, label="Temperature")
top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
with gr.Column(scale=3):
chatbot = gr.Chatbot(label="Assistant", type="messages")
with gr.Row(elem_classes="user-input-row"):
user_box = gr.Textbox(placeholder="Describe your CRUD/backend task…", lines=3, scale=5)
send_btn = gr.Button("▶️", variant="primary", scale=1)
regen_btn = gr.Button("🔁", variant="secondary", scale=1)
    # --- Wire callbacks ---
    refresh_btn.click(refresh_sessions_cb, outputs=[session_list, edit_title_box, save_title_btn, del_btn, refresh_btn])
    new_btn.click(new_chat_cb, outputs=[session_list, chatbot, user_box, edit_title_box, save_title_btn, del_btn, refresh_btn])
    del_btn.click(delete_chat_cb, inputs=session_list, outputs=[session_list, chatbot, edit_title_box, save_title_btn, del_btn, refresh_btn])
    session_list.change(load_session_cb, inputs=session_list, outputs=[chatbot, edit_title_box, save_title_btn, del_btn, refresh_btn])
    send_btn.click(send_cb, inputs=[user_box, session_list, chatbot, system_box, max_tokens, temperature, top_p, model_choice, dataset_choice], outputs=[chatbot, user_box, session_list])
    user_box.submit(send_cb, inputs=[user_box, session_list, chatbot, system_box, max_tokens, temperature, top_p, model_choice, dataset_choice], outputs=[chatbot, user_box, session_list])
    regen_btn.click(regenerate_cb, inputs=[session_list, system_box, max_tokens, temperature, top_p, model_choice, dataset_choice], outputs=chatbot)
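    # Note: the "Rename" and "Save" buttons above are declared but never wired to a
    # handler in this file (and rename_btn is never toggled visible by the existing
    # callbacks). A minimal sketch of what such wiring could look like, assuming the
    # intended behavior is renaming the currently selected session; the helper names
    # rename_chat_cb / save_title_cb are hypothetical, not part of the original app:
    def rename_chat_cb(selected_label):
        # Reveal the title editor, pre-filled with the current title if one is selected.
        current = selected_label.split("•", 1)[1].strip() if selected_label and "•" in selected_label else ""
        return gr.update(value=current, visible=True), gr.update(visible=True)
    def save_title_cb(selected_label, new_title):
        sid = label_to_id(selected_label)
        if sid and new_title.strip():
            conn = db()
            conn.execute("UPDATE sessions SET title=? WHERE id=?", (new_title.strip(), sid))
            conn.commit()
            conn.close()
        labels, _ = list_sessions()
        selected = next((lbl for lbl in labels if lbl.startswith(f"{sid} ")), None) if sid else (labels[0] if labels else None)
        return gr.update(choices=labels, value=selected), gr.update(visible=False), gr.update(visible=False)
    rename_btn.click(rename_chat_cb, inputs=session_list, outputs=[edit_title_box, save_title_btn])
    save_title_btn.click(save_title_cb, inputs=[session_list, edit_title_box], outputs=[session_list, edit_title_box, save_title_btn])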
if __name__ == "__main__":
    demo.launch()