# Hugging Face Spaces page capture — Space status: Sleeping
import os
import sqlite3
import time
from datetime import datetime, timezone

import gradio as gr
from huggingface_hub import InferenceClient
# ---------------------------
# Config
# ---------------------------
# Dropdown label -> Hugging Face model repo id passed to InferenceClient.
MODELS = {
    "Meta LLaMA 3.1 (8B Instruct)": "meta-llama/Llama-3.1-8B-Instruct",
    "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
}
HF_TOKEN = os.getenv("HF_TOKEN")  # set this in your Space's Secrets
# SQLite file holding the chat history. Set HISTORY_DB_PATH (e.g. to a path on
# a persistent-storage mount) so history survives restarts; an unset or empty
# variable keeps the original relative "history.db".
DB_PATH = os.getenv("HISTORY_DB_PATH") or "history.db"
# Default system prompt pre-filled in the "System message" textbox.
SYSTEM_DEFAULT = (
    "You are a coding assistant. Respond only with clean and complete code "
    "unless explanation is explicitly requested. Prefer full CRUD scaffolds, "
    "with files, paths, and commands when asked."
)
| # --------------------------- | |
| # DB Setup | |
| # --------------------------- | |
| def db(): | |
| conn = sqlite3.connect(DB_PATH) | |
| conn.execute("PRAGMA journal_mode=WAL;") | |
| return conn | |
def init_db():
    """Create the ``sessions`` and ``messages`` tables if they do not exist.

    Also creates an index on ``messages(session_id, id)``, which matches the
    per-session, id-ordered history queries used by the app.
    """
    conn = db()
    try:
        with conn:  # commit on success, roll back on error
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS sessions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    title TEXT NOT NULL,
                    created_at TEXT NOT NULL
                )
                """
            )
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id INTEGER NOT NULL,
                    role TEXT NOT NULL,
                    content TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE
                )
                """
            )
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_messages_session "
                "ON messages(session_id, id)"
            )
    finally:
        # Always release the connection, even if a statement failed.
        conn.close()
def create_session(title: str = "New chat") -> int:
    """Insert a new chat session and return its id.

    Args:
        title: Title stored for the session (shown in the session picker).

    Returns:
        The auto-generated primary key of the new ``sessions`` row.
    """
    conn = db()
    try:
        with conn:  # commit on success, roll back on error
            cur = conn.execute(
                "INSERT INTO sessions (title, created_at) VALUES (?, ?)",
                # datetime.utcnow() is deprecated; store an explicit UTC timestamp.
                (title, datetime.now(timezone.utc).isoformat()),
            )
        return cur.lastrowid
    finally:
        conn.close()
def delete_session(session_id: int):
    """Delete a session and all of its messages in a single transaction.

    Args:
        session_id: Primary key of the session to remove.
    """
    conn = db()
    try:
        with conn:  # both deletes commit together or not at all
            # Explicit child delete keeps this correct even when foreign-key
            # enforcement (and thus ON DELETE CASCADE) is off.
            conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
            conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
    finally:
        conn.close()
def list_sessions():
    """Return all sessions, newest first.

    Returns:
        A ``(labels, rows)`` tuple. ``rows`` holds ``(id, title)`` tuples and
        ``labels`` the matching ``"<id> • <title>"`` strings shown in the picker.
    """
    conn = db()
    try:
        rows = conn.execute("SELECT id, title FROM sessions ORDER BY id DESC").fetchall()
    finally:
        conn.close()
    labels = [f"{sid} • {title}" for sid, title in rows]
    return labels, rows
def get_messages(session_id: int):
    """Return a session's messages in insertion order.

    Args:
        session_id: Session whose history to load.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts ordered by message id.
    """
    conn = db()
    try:
        rows = conn.execute(
            "SELECT role, content FROM messages WHERE session_id = ? ORDER BY id ASC",
            (session_id,),
        ).fetchall()
    finally:
        conn.close()
    return [{"role": role, "content": content} for role, content in rows]
def add_message(session_id: int, role: str, content: str):
    """Append one message to a session's history.

    Args:
        session_id: Session the message belongs to.
        role: Chat role, e.g. ``"user"`` or ``"assistant"``.
        content: Message text.
    """
    conn = db()
    try:
        with conn:  # commit on success, roll back on error
            conn.execute(
                "INSERT INTO messages (session_id, role, content, created_at) VALUES (?, ?, ?, ?)",
                # datetime.utcnow() is deprecated; store an explicit UTC timestamp.
                (session_id, role, content, datetime.now(timezone.utc).isoformat()),
            )
    finally:
        conn.close()
| def update_session_title_if_needed(session_id: int, first_user_text: str): | |
| conn = db() | |
| cur = conn.cursor() | |
| cur.execute("SELECT COUNT(*) FROM messages WHERE session_id=? AND role='user'", (session_id,)) | |
| count_users = cur.fetchone()[0] | |
| if count_users == 1: | |
| title = first_user_text.strip().split("\n")[0] | |
| title = (title[:50] + "…") if len(title) > 50 else title | |
| cur.execute("UPDATE sessions SET title=? WHERE id=?", (title or "New chat", session_id)) | |
| conn.commit() | |
| conn.close() | |
| # --------------------------- | |
| # Helpers | |
| # --------------------------- | |
| def label_to_id(label: str | None) -> int | None: | |
| if not label: | |
| return None | |
| try: | |
| return int(label.split("•", 1)[0].strip()) | |
| except Exception: | |
| return None | |
def build_api_messages(session_id: int, system_message: str):
    """Build the chat-completion message list for a session.

    The stripped system prompt goes first, followed by the stored history.
    A blank (or None) system prompt is omitted instead of being sent as an
    empty system turn.

    Args:
        session_id: Session whose stored history to include.
        system_message: System prompt text from the UI.

    Returns:
        A list of ``{"role", "content"}`` dicts.
    """
    system_text = (system_message or "").strip()
    msgs = [{"role": "system", "content": system_text}] if system_text else []
    msgs.extend(get_messages(session_id))
    return msgs
def get_client(model_choice: str):
    """Create an InferenceClient for the model behind a dropdown label.

    Labels not present in MODELS fall back to the first configured model.
    """
    fallback_model = [*MODELS.values()][0]
    model_id = MODELS[model_choice] if model_choice in MODELS else fallback_model
    return InferenceClient(model_id, token=HF_TOKEN)
# ---------------------------
# Gradio Callbacks
# ---------------------------
def refresh_sessions_cb():
    """Reload the picker choices and select the newest session (or nothing)."""
    choices, _rows = list_sessions()
    newest = choices[0] if choices else None
    return gr.update(choices=choices, value=newest)
def new_chat_cb():
    """Create an empty session, select it, and clear the chat and input box.

    Returns:
        (picker update, empty chat history, empty textbox value).
    """
    new_id = create_session("New chat")
    choices, _rows = list_sessions()
    prefix = f"{new_id} "
    matching = [choice for choice in choices if choice.startswith(prefix)]
    picker = gr.update(choices=choices, value=matching[0] if matching else None)
    return picker, [], ""
def load_session_cb(selected_label):
    """Return the stored history of the picked session, or [] if none is picked."""
    session_id = label_to_id(selected_label)
    return get_messages(session_id) if session_id else []
def delete_chat_cb(selected_label):
    """Delete the picked session, then select the newest remaining one.

    Returns:
        (picker update, empty chat history).
    """
    doomed_id = label_to_id(selected_label)
    if doomed_id:
        delete_session(doomed_id)
    remaining, _rows = list_sessions()
    picker = gr.update(choices=remaining, value=remaining[0] if remaining else None)
    return picker, []
def _system_prefixed(system_message, history):
    """Return copies of ``history`` prefixed with the stripped system prompt.

    A blank (or None) system prompt is omitted instead of sent as an empty turn.
    """
    system_text = (system_message or "").strip()
    head = [{"role": "system", "content": system_text}] if system_text else []
    return head + [dict(msg) for msg in history]


def _stream_reply(model_choice, api_messages, max_tokens, temperature, top_p):
    """Stream a chat completion, yielding the accumulated text after each delta.

    Chunks whose ``choices`` list is empty are skipped instead of raising
    IndexError. That avoids turning an already-streamed answer into an error.
    """
    client = get_client(model_choice)
    partial = ""
    for chunk in client.chat_completion(
        messages=api_messages,
        max_tokens=int(max_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        stream=True,
    ):
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta.content or ""
        if delta:
            partial += delta
            yield partial


def _delete_last_assistant_message(session_id):
    """Delete a session's newest message, but only if it is an assistant reply."""
    conn = db()
    try:
        with conn:
            conn.execute(
                """
                DELETE FROM messages
                WHERE id = (
                    SELECT id FROM messages WHERE session_id = ? ORDER BY id DESC LIMIT 1
                ) AND role = 'assistant'
                """,
                (session_id,),
            )
    finally:
        conn.close()


def send_cb(user_text, selected_label, chatbot_msgs, system_message, max_tokens, temperature, top_p, model_choice):
    """Store the user's message, stream the model reply, and store the reply.

    Creates a session first when none is picked. Blank submissions are ignored.

    Yields:
        ``(chat messages, textbox value, session-picker value or update)``
        tuples for the ``[chatbot, user_box, session_list]`` outputs.
    """
    if not (user_text or "").strip():
        # Nothing to store or send; leave every output untouched.
        yield gr.update(), gr.update(), gr.update()
        return

    sid = label_to_id(selected_label)
    created = sid is None
    if created:
        sid = create_session("New chat")

    add_message(sid, "user", user_text)
    update_session_title_if_needed(sid, user_text)

    if created:
        # The picker's choices do not contain the new session yet, so push the
        # refreshed choices together with the selection.
        labels, _ = list_sessions()
        selected_label = next((lbl for lbl in labels if lbl.startswith(f"{sid} ")), None)
        session_out = gr.update(choices=labels, value=selected_label)
    else:
        session_out = selected_label

    history = get_messages(sid)
    api_messages = _system_prefixed(system_message, history)
    display_msgs = history + [{"role": "assistant", "content": ""}]

    partial = ""
    try:
        for partial in _stream_reply(model_choice, api_messages, max_tokens, temperature, top_p):
            display_msgs[-1]["content"] = partial
            yield display_msgs, "", session_out
        add_message(sid, "assistant", partial)
    except Exception as e:
        display_msgs[-1]["content"] = f"⚠️ Error: {str(e)}"
    # Final update: shows the error, or clears the textbox even if nothing streamed.
    yield display_msgs, "", session_out


def regenerate_cb(selected_label, system_message, max_tokens, temperature, top_p, model_choice):
    """Re-run the model on the picked session, replacing its last assistant reply.

    The previous reply is deleted only after the new one has been generated
    successfully, so a failed call does not lose it.

    Yields:
        The chat message list for the ``chatbot`` output.
    """
    sid = label_to_id(selected_label)
    msgs = get_messages(sid) if sid is not None else []
    if not msgs:
        # Generator callbacks must yield to update the UI; a bare return emits nothing.
        yield []
        return

    had_reply = msgs[-1]["role"] == "assistant"
    history = msgs[:-1] if had_reply else msgs
    api_messages = _system_prefixed(system_message, history)
    display_msgs = history + [{"role": "assistant", "content": ""}]

    partial = ""
    try:
        for partial in _stream_reply(model_choice, api_messages, max_tokens, temperature, top_p):
            display_msgs[-1]["content"] = partial
            yield display_msgs
        if had_reply:
            _delete_last_assistant_message(sid)
        add_message(sid, "assistant", partial)
    except Exception as e:
        display_msgs[-1]["content"] = f"⚠️ Error: {str(e)}"
    yield display_msgs
# ---------------------------
# App UI
# ---------------------------
init_db()

# Make sure at least one chat exists so the session picker starts non-empty.
labels, _ = list_sessions()
if not labels:
    first_sid = create_session("New chat")
    labels, _ = list_sessions()
default_selected = labels[0] if labels else None

# Raw <style> block injected via gr.HTML; it recolours every button green.
_GREEN_BUTTON_CSS = """
<style>
button {
background-color: #22c55e !important;
color: #ffffff !important;
border: none !important;
}
button:hover {
background-color: #16a34a !important;
}
button:focus {
outline: 2px solid #166534 !important;
outline-offset: 2px;
}
</style>
"""
model_labels = list(MODELS.keys())

with gr.Blocks(title="LLaMA/Mistral CRUD Automation (with History)", theme=gr.themes.Soft()) as demo:
    gr.HTML(_GREEN_BUTTON_CSS)
    gr.Markdown("## 🦙🤖 LLaMA & Mistral CRUD Automation — with Persistent History")

    with gr.Row():
        # Sidebar: session picker, model choice and sampling controls.
        with gr.Column(scale=1, min_width=260):
            gr.Markdown("### 📁 Sessions")
            session_list = gr.Radio(
                choices=labels,
                value=default_selected,
                label="Your chats",
                interactive=True,
            )
            with gr.Row():
                new_btn = gr.Button("➕ New Chat", variant="primary")
                del_btn = gr.Button("🗑️ Delete", variant="stop")
                refresh_btn = gr.Button("🔄 Refresh", variant="secondary")

            gr.Markdown("### 🤖 Model Selection")
            model_choice = gr.Dropdown(
                choices=model_labels,
                value=model_labels[0],
                label="Choose a model",
                interactive=True,
            )

            gr.Markdown("### ⚙️ Generation Settings")
            system_box = gr.Textbox(value=SYSTEM_DEFAULT, label="System message", lines=4)
            max_tokens = gr.Slider(256, 4096, value=1200, step=16, label="Max tokens")
            temperature = gr.Slider(0.0, 2.0, value=0.25, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

        # Main pane: transcript, prompt box and action buttons.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Assistant", height=520, type="messages")
            with gr.Row():
                user_box = gr.Textbox(placeholder="Describe your CRUD task…", lines=3, scale=5)
            with gr.Row():
                send_btn = gr.Button("Send ▶️", variant="primary")
                regen_btn = gr.Button("Regenerate 🔁", variant="secondary")

    # Event wiring
    refresh_btn.click(refresh_sessions_cb, outputs=session_list)
    new_btn.click(new_chat_cb, outputs=[session_list, chatbot, user_box])
    del_btn.click(delete_chat_cb, inputs=session_list, outputs=[session_list, chatbot])
    session_list.change(load_session_cb, inputs=session_list, outputs=chatbot)

    # Clicking "Send" and pressing submit in the prompt box run the same callback.
    send_inputs = [user_box, session_list, chatbot, system_box, max_tokens, temperature, top_p, model_choice]
    send_outputs = [chatbot, user_box, session_list]
    for register in (send_btn.click, user_box.submit):
        register(send_cb, inputs=list(send_inputs), outputs=list(send_outputs))

    regen_btn.click(
        regenerate_cb,
        inputs=[session_list, system_box, max_tokens, temperature, top_p, model_choice],
        outputs=chatbot,
    )

if __name__ == "__main__":
    demo.launch()