import os
import sqlite3
from datetime import datetime

import gradio as gr
from huggingface_hub import InferenceClient
from datasets import load_dataset
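# Backend-focused CRUD assistant: chat sessions and messages persist in a local
# SQLite file, the UI is a Gradio Blocks app, and replies stream from the
# Hugging Face Inference API via InferenceClient.chat_completion. Frontend-heavy
# requests are declined by a simple keyword filter.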
# ---------------------------
# Config
# ---------------------------
MODELS = {
    "Meta LLaMA 3.1 (8B Instruct)": "meta-llama/Llama-3.1-8B-Instruct",
    "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
}
DATASETS = ["The Stack", "CodeXGLUE"]

# Read the API token from the environment (e.g. a Space secret). Without it the
# client makes unauthenticated calls, which fail for gated models such as Llama 3.1.
HF_TOKEN = os.getenv("HF_TOKEN")
DB_PATH = "history.db"

SYSTEM_DEFAULT = (
    "Specializes in databases, APIs, auth, CRUD. "
    "Provides complete backend code scaffolds. "
    "Declines frontend-heavy requests."
)
# ---------------------------
# DB Setup
# ---------------------------
def db():
    # Note: SQLite only honours ON DELETE CASCADE when the connection runs
    # "PRAGMA foreign_keys = ON"; delete_session() therefore removes messages explicitly.
    return sqlite3.connect(DB_PATH)

def init_db():
    conn = db()
    cur = conn.cursor()
    cur.execute("""
        CREATE TABLE IF NOT EXISTS sessions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            title TEXT NOT NULL,
            created_at TEXT NOT NULL
        )
    """)
    cur.execute("""
        CREATE TABLE IF NOT EXISTS messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id INTEGER NOT NULL,
            role TEXT NOT NULL,
            content TEXT NOT NULL,
            created_at TEXT NOT NULL,
            FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE
        )
    """)
    conn.commit()
    conn.close()
def create_session(title: str = "New chat") -> int:
    conn = db()
    cur = conn.cursor()
    cur.execute(
        "INSERT INTO sessions (title, created_at) VALUES (?, ?)",
        (title, datetime.utcnow().isoformat()),
    )
    sid = cur.lastrowid
    conn.commit()
    conn.close()
    return sid
def delete_session(session_id: int):
    conn = db()
    cur = conn.cursor()
    cur.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
    cur.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
    conn.commit()
    conn.close()
def list_sessions():
    conn = db()
    cur = conn.cursor()
    cur.execute("SELECT id, title FROM sessions ORDER BY id DESC")
    rows = cur.fetchall()
    conn.close()
    labels = [f"{r[0]} • {r[1]}" for r in rows]
    return labels, rows
def get_messages(session_id: int):
    conn = db()
    cur = conn.cursor()
    cur.execute("""
        SELECT role, content
        FROM messages
        WHERE session_id = ?
        ORDER BY id ASC
    """, (session_id,))
    msgs = [{"role": role, "content": content} for (role, content) in cur.fetchall()]
    conn.close()
    return msgs
def add_message(session_id: int, role: str, content: str):
    conn = db()
    cur = conn.cursor()
    cur.execute(
        "INSERT INTO messages (session_id, role, content, created_at) VALUES (?, ?, ?, ?)",
        (session_id, role, content, datetime.utcnow().isoformat()),
    )
    conn.commit()
    conn.close()
def update_session_title_if_needed(session_id: int, first_user_text: str):
    # After the first user message, replace the placeholder "New chat" title
    # with a truncated copy of that message.
    conn = db()
    cur = conn.cursor()
    cur.execute("SELECT COUNT(*) FROM messages WHERE session_id=? AND role='user'", (session_id,))
    if cur.fetchone()[0] == 1:
        title = first_user_text.strip().split("\n")[0]
        title = (title[:50] + "…") if len(title) > 50 else title
        cur.execute("UPDATE sessions SET title=? WHERE id=?", (title or "New chat", session_id))
        conn.commit()
    conn.close()
def label_to_id(label: str | None) -> int | None:
    # Labels look like "12 • My chat title"; recover the numeric session id.
    if not label:
        return None
    try:
        return int(label.split("•", 1)[0].strip())
    except (ValueError, AttributeError):
        return None
def build_api_messages(session_id: int, system_message: str):
    msgs = [{"role": "system", "content": system_message.strip()}]
    msgs.extend(get_messages(session_id))
    return msgs

def get_client(model_choice: str):
    return InferenceClient(MODELS.get(model_choice, list(MODELS.values())[0]), token=HF_TOKEN)
def load_dataset_by_name(name: str):
    # Defined for the dataset dropdown, but not yet called from the generation path.
    if name == "The Stack":
        return load_dataset("bigcode/the-stack", split="train")
    if name == "CodeXGLUE":
        return load_dataset("google/code_x_glue_cc_code_to_code_trans", split="train")
    return None
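# "bigcode/the-stack" is multi-terabyte, so materialising its full train split is
# impractical on a Space. If the dropdown is ever wired into generation, streaming
# a handful of records is far cheaper. `peek_dataset` below is an illustrative
# sketch only and is not referenced anywhere else in this app.
def peek_dataset(name: str, n: int = 3):
    repo = {
        "The Stack": "bigcode/the-stack",
        "CodeXGLUE": "google/code_x_glue_cc_code_to_code_trans",
    }.get(name)
    if repo is None:
        return []
    # streaming=True returns an iterable dataset without downloading everything
    ds = load_dataset(repo, split="train", streaming=True)
    it = iter(ds)
    return [next(it) for _ in range(n)]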
FRONTEND_KEYWORDS = ["react", "vue", "angular", "html", "css", "javascript", "tailwind", "recharts", "typescript"]

def is_frontend_request(user_text: str) -> bool:
    return any(kw in user_text.lower() for kw in FRONTEND_KEYWORDS)
# ---------------------------
# Callbacks
# ---------------------------
def refresh_sessions_cb():
    # Returns updates for: session_list, edit_title_box, save_title_btn, del_btn, refresh_btn.
    labels, _ = list_sessions()
    selected = labels[0] if labels else None
    visible = bool(selected)
    return (
        gr.update(choices=labels, value=selected),
        gr.update(visible=visible),
        gr.update(visible=visible),
        gr.update(visible=visible),
        gr.update(visible=visible),
    )

def new_chat_cb():
    sid = create_session("New chat")
    labels, _ = list_sessions()
    selected = next((lbl for lbl in labels if lbl.startswith(f"{sid} ")), None)
    return (
        gr.update(choices=labels, value=selected),
        [], "",  # chatbot cleared, user box cleared
        gr.update(visible=True), gr.update(visible=True),
        gr.update(visible=True), gr.update(visible=True),
    )
def load_session_cb(selected_label):
    sid = label_to_id(selected_label)
    if not sid:
        return [], gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    return get_messages(sid), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

def delete_chat_cb(selected_label):
    sid = label_to_id(selected_label)
    if sid:
        delete_session(sid)
    labels, _ = list_sessions()
    selected = labels[0] if labels else None
    visible = bool(selected)
    return gr.update(choices=labels, value=selected), [], gr.update(visible=visible), gr.update(visible=visible), gr.update(visible=visible), gr.update(visible=visible)
# --- Send ---
def send_cb(user_text, selected_label, chatbot_msgs, system_message, max_tokens, temperature, top_p, model_choice, dataset_choice, *args):
    sid = label_to_id(selected_label) or create_session("New chat")
    add_message(sid, "user", user_text)
    update_session_title_if_needed(sid, user_text)
    display_msgs = chatbot_msgs[:] + [{"role": "user", "content": user_text}]
    # Decline frontend-oriented requests outright.
    if is_frontend_request(user_text):
        apology = "⚠️ I'm a backend-focused assistant and cannot provide frontend code."
        display_msgs.append({"role": "assistant", "content": apology})
        add_message(sid, "assistant", apology)
        yield display_msgs, "", selected_label
        return
    # Show a placeholder bubble, then stream tokens into it.
    display_msgs.append({"role": "assistant", "content": "…"})
    yield display_msgs, "", selected_label
    client = get_client(model_choice)
    api_messages = build_api_messages(sid, system_message)
    partial = ""
    try:
        for chunk in client.chat_completion(messages=api_messages, max_tokens=int(max_tokens), temperature=float(temperature), top_p=float(top_p), stream=True):
            if not hasattr(chunk, "choices") or not chunk.choices:
                continue
            choice = chunk.choices[0]
            delta = (
                getattr(getattr(choice, "delta", None), "content", None)
                or getattr(getattr(choice, "message", None), "content", None)
                or ""
            )
            if delta:
                partial += delta
                display_msgs[-1]["content"] = partial
                yield display_msgs, "", selected_label
        add_message(sid, "assistant", partial)
    except Exception as e:
        display_msgs[-1]["content"] = f"⚠️ Error: {e}"
        yield display_msgs, "", selected_label
# --- Regenerate ---
def regenerate_cb(selected_label, system_message, max_tokens, temperature, top_p, model_choice, dataset_choice):
    sid = label_to_id(selected_label)
    if not sid:
        # This generator drives only the chatbot output, so yield a single value.
        yield []
        return
    # Drop the most recent assistant reply so it can be regenerated.
    msgs = get_messages(sid)
    if msgs and msgs[-1]["role"] == "assistant":
        conn = db()
        cur = conn.cursor()
        cur.execute("DELETE FROM messages WHERE id=(SELECT id FROM messages WHERE session_id=? ORDER BY id DESC LIMIT 1)", (sid,))
        conn.commit()
        conn.close()
        msgs = get_messages(sid)
    display_msgs = msgs + [{"role": "assistant", "content": ""}]
    client = get_client(model_choice)
    partial = ""
    try:
        for chunk in client.chat_completion(messages=[{"role": "system", "content": system_message.strip()}] + msgs, max_tokens=int(max_tokens), temperature=float(temperature), top_p=float(top_p), stream=True):
            if not hasattr(chunk, "choices") or not chunk.choices:
                continue
            choice = chunk.choices[0]
            delta = (
                getattr(getattr(choice, "delta", None), "content", None)
                or getattr(getattr(choice, "message", None), "content", None)
                or ""
            )
            if delta:
                partial += delta
                display_msgs[-1]["content"] = partial
                yield display_msgs
        add_message(sid, "assistant", partial)
    except Exception as e:
        display_msgs[-1]["content"] = f"⚠️ Error: {e}"
        yield display_msgs
# ---------------------------
# App UI
# ---------------------------
init_db()
labels, _ = list_sessions()
if not labels:
    create_session("New chat")
    labels, _ = list_sessions()
default_selected = labels[0] if labels else None
with gr.Blocks(title="Backend-Focused LLaMA/Mistral CRUD Assistant", theme=gr.themes.Soft()) as demo:
    # --- Custom CSS ---
    gr.HTML("""
    <style>
    /* ========== Compact section headings ========== */
    .compact-heading h3 {
        font-size: 0.9rem !important;
        font-weight: 600 !important;
        margin: 0.3rem 0 !important;
        padding: 0 !important;
    }
    /* ========== Main top title smaller ========== */
    .main-title h2 {
        font-size: 1.2rem !important;
        font-weight: 700 !important;
        margin-bottom: 0.6rem !important;
    }
    /* ========== Reduce spacing below sliders & dropdowns ========== */
    .compact-sliders .gr-slider,
    .compact-sliders .gr-number,
    .compact-sliders .gr-dropdown {
        margin-bottom: 0.4rem !important;
    }
    /* ========== Compact the left panel column ========== */
    .left-panel .gr-column {
        gap: 0.3rem !important;
    }
    /* ========== Tiny / pressed buttons ========== */
    .tiny-btn .gr-button {
        font-size: 0.75rem !important;   /* smaller text */
        padding: 2px 6px !important;     /* tight padding */
        min-height: 24px !important;     /* shorter button */
        line-height: 1.1 !important;     /* compact vertical spacing */
        border-radius: 4px !important;   /* slightly rounded */
    }
    /* ========== Remove extra top/left page gaps ========== */
    body, html {
        margin: 0 !important;
        padding: 0 !important;
    }
    .gradio-container {
        margin-top: -40px !important;    /* was -10px, move up more */
        margin-left: -40px !important;
        padding-top: 0 !important;
        padding-left: 0 !important;
    }
    .main-title,
    .gr-block:first-of-type {
        margin-top: -10px !important;    /* was -5px, pull up further */
        margin-left: -10px !important;
        padding-top: 0 !important;
        padding-left: 0 !important;
    }
    /* ========== Compact sliders: reduce vertical height and spacing ========== */
    .compact-sliders .gr-slider {
        height: 28px !important;          /* shorter track */
        padding: 0 !important;            /* remove inner padding */
        margin-bottom: 0.3rem !important; /* less spacing below each slider */
    }
    /* Compact slider label and value text */
    .compact-sliders .gr-slider label,
    .compact-sliders .gr-slider span {
        font-size: 0.8rem !important;     /* smaller text */
    }
    /* Tighten the numeric input box on sliders */
    .compact-sliders input[type="number"] {
        max-width: 70px !important;       /* smaller numeric input width */
        font-size: 0.8rem !important;
        padding: 2px 4px !important;
    }
    </style>
    """)
    # gr.Markdown("## 🗄️ LLaMA & Mistral Backend-Focused CRUD Automation — Persistent History")
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=260):
            # gr.Markdown("### 📁 Sessions")
            session_list = gr.Radio(choices=labels, value=default_selected, label="Your Chats", interactive=True)
            with gr.Row(elem_id="session-actions-row"):
                new_btn = gr.Button("New Chat")
                rename_btn = gr.Button("Rename", visible=False)
                save_title_btn = gr.Button("Save", visible=False)
                del_btn = gr.Button("Delete", visible=False)
                refresh_btn = gr.Button("Refresh", visible=False)
            edit_title_box = gr.Textbox(label="Edit Chat Name", placeholder="Type new chat name…", visible=False)

            # Model selection
            # gr.Markdown("### 🤖 Model Selection", elem_classes="compact-heading")
            model_choice = gr.Dropdown(choices=list(MODELS.keys()), value=list(MODELS.keys())[0], label="Choose a model", interactive=True)

            # Dataset selection
            # gr.Markdown("### 📚 Dataset Selection", elem_classes="compact-heading")
            dataset_choice = gr.Dropdown(choices=DATASETS, value=DATASETS[0], label="Select a dataset", interactive=True)

            # Generation settings
            # gr.Markdown("### ⚙️ Generation Settings", elem_classes="compact-heading")
            with gr.Group(elem_classes="system-message-compact"):
                system_box = gr.Textbox(value=SYSTEM_DEFAULT, label="System message", lines=1)
            with gr.Group(elem_classes="compact-sliders"):
                max_tokens = gr.Slider(256, 4096, value=1200, step=16, label="Max tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.25, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Assistant", type="messages")
            with gr.Row(elem_classes="user-input-row"):
                user_box = gr.Textbox(placeholder="Describe your CRUD/backend task…", lines=3, scale=5)
                send_btn = gr.Button("▶️", variant="primary", scale=1)
                regen_btn = gr.Button("🔁", variant="secondary", scale=1)
    # --- Wire callbacks ---
    refresh_btn.click(refresh_sessions_cb, outputs=[session_list, edit_title_box, save_title_btn, del_btn, refresh_btn])
    new_btn.click(new_chat_cb, outputs=[session_list, chatbot, user_box, edit_title_box, save_title_btn, del_btn, refresh_btn])
    del_btn.click(delete_chat_cb, inputs=session_list, outputs=[session_list, chatbot, edit_title_box, save_title_btn, del_btn, refresh_btn])
    session_list.change(load_session_cb, inputs=session_list, outputs=[chatbot, edit_title_box, save_title_btn, del_btn, refresh_btn])
    send_btn.click(send_cb, inputs=[user_box, session_list, chatbot, system_box, max_tokens, temperature, top_p, model_choice, dataset_choice], outputs=[chatbot, user_box, session_list])
    user_box.submit(send_cb, inputs=[user_box, session_list, chatbot, system_box, max_tokens, temperature, top_p, model_choice, dataset_choice], outputs=[chatbot, user_box, session_list])
    regen_btn.click(regenerate_cb, inputs=[session_list, system_box, max_tokens, temperature, top_p, model_choice, dataset_choice], outputs=chatbot)
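    # --- Rename wiring (sketch) ---
    # The "Rename" / "Save" buttons and the title textbox above are declared but not
    # wired in the original. A minimal sketch follows; `rename_session`, `show_rename_cb`
    # and `save_title_cb` are hypothetical helpers added here for illustration only.
    # Note that rename_btn itself is created with visible=False and is never shown by the
    # existing callbacks, so it would also need to be toggled (e.g. alongside del_btn).
    def rename_session(session_id: int, new_title: str):
        conn = db()
        cur = conn.cursor()
        cur.execute("UPDATE sessions SET title=? WHERE id=?", (new_title, session_id))
        conn.commit()
        conn.close()

    def show_rename_cb():
        # Reveal the edit box and the Save button when Rename is pressed.
        return gr.update(visible=True), gr.update(visible=True)

    def save_title_cb(selected_label, new_title):
        sid = label_to_id(selected_label)
        if sid and new_title.strip():
            rename_session(sid, new_title.strip())
        labels, _ = list_sessions()
        selected = next((lbl for lbl in labels if lbl.startswith(f"{sid} ")), None) if sid else None
        return gr.update(choices=labels, value=selected), gr.update(visible=False, value="")

    rename_btn.click(show_rename_cb, outputs=[edit_title_box, save_title_btn])
    save_title_btn.click(save_title_cb, inputs=[session_list, edit_title_box], outputs=[session_list, edit_title_box])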
if __name__ == "__main__":
    demo.launch()
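# Dependencies implied by the imports above (a plausible requirements.txt for this
# Space; the original file does not pin versions): gradio, huggingface_hub, datasets.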