# NOTE: upstream file-hosting page metadata (author, commit, size) preserved as a comment:
# WebashalarForML — Update app.py — 4807519 verified — raw / history / blame — 11.5 kB
#!/usr/bin/env python3
# filename: app_refactored.py
import os
import json
import logging
import re
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
from flask import Flask, request, jsonify
from flask_cors import CORS
from dotenv import load_dotenv
# Replace with your LLM client import; kept generic here.
# from langchain_groq import ChatGroq
# === Config ===
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
raise RuntimeError("GROQ_API_KEY not set in environment")
LLM_MODEL = os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct")
LLM_TIMEOUT_SECONDS = float(os.getenv("LLM_TIMEOUT_SECONDS", "20"))
MAX_HISTORY_MESSAGES = int(os.getenv("MAX_HISTORY_MESSAGES", "12"))
VALID_LANGUAGES = {"python", "javascript", "java", "c++", "c#", "go", "ruby", "php", "typescript", "swift"}
# === Logging ===
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("code-assistant")
# === LLM client (example) ===
# NOTE: adapt this block to match your SDK. Keep a tolerant accessor for response text.
class DummyLLM:
def __init__(self, **kwargs):
self.kwargs = kwargs
def invoke(self, messages: List[Dict[str, str]], timeout: Optional[float] = None):
# stub: replace with real client call
class Resp: pass
r = Resp()
r.content = json.dumps({
"assistant_reply": "This is a dummy reply. Replace with real LLM client.",
"code_snippet": "",
"state_updates": {"conversationSummary": "dummy", "language": "Python"},
"suggested_tags": ["example"]
})
return r
# llm = ChatGroq(model=LLM_MODEL, api_key=GROQ_API_KEY, temperature=0.1, max_tokens=2048)
llm = DummyLLM(model=LLM_MODEL, api_key=GROQ_API_KEY) # replace with real client
# === Prompt ===
SYSTEM_PROMPT = (
"You are an expert programming assistant. Prefer to return a JSON object with keys: "
"assistant_reply (string), code_snippet (string, optional, can be multiline), "
"state_updates (object), suggested_tags (array). If including code, put it in triple backticks. "
"Do NOT escape newlines in code_snippet; return natural multi-line strings."
)
# === Utilities ===
def clamp_summary(s: str, max_len: int = 1200) -> str:
s = (s or "").strip()
return s if len(s) <= max_len else s[:max_len-3] + "..."
def canonicalize_language(text: Optional[str]) -> Optional[str]:
if not text:
return None
t = text.strip().lower()
# quick membership test
for lang in VALID_LANGUAGES:
if lang in t or t == lang:
return lang
return None
def try_parse_json(s: str) -> Optional[Dict[str, Any]]:
try:
return json.loads(s)
except Exception:
return None
def extract_code_fence(text: str) -> Optional[str]:
m = re.search(r"```(?:[a-zA-Z0-9_+\-]*)\n([\s\S]*?)```", text)
return m.group(1).strip() if m else None
def parse_llm_output(raw: str) -> Dict[str, Any]:
"""
Tolerant multi-strategy parser:
1) Direct JSON
2) JSON inside a ```json``` fence
3) Heuristic extraction: assistant_reply lines, code fences for code_snippet, simple state_updates line (json)
"""
default = {
"assistant_reply": "I couldn't parse the model response. Please rephrase or simplify the request.",
"code_snippet": "",
"state_updates": {"conversationSummary": "", "language": "python"},
"suggested_tags": [],
"parse_ok": False,
}
if not raw or not isinstance(raw, str):
return default
raw = raw.strip()
# 1) direct JSON
parsed = try_parse_json(raw)
if parsed and isinstance(parsed, dict) and "assistant_reply" in parsed:
parsed.setdefault("code_snippet", "")
parsed.setdefault("state_updates", {})
parsed.setdefault("suggested_tags", [])
parsed["parse_ok"] = True
return parsed
# 2) JSON inside any code fence (```json ... ```)
m_json_fence = re.search(r"```json\s*([\s\S]*?)```", raw, re.IGNORECASE)
if m_json_fence:
candidate = m_json_fence.group(1)
parsed = try_parse_json(candidate)
if parsed and "assistant_reply" in parsed:
parsed.setdefault("code_snippet", "")
parsed.setdefault("state_updates", {})
parsed.setdefault("suggested_tags", [])
parsed["parse_ok"] = True
return parsed
# 3) Heuristics: find assistant_reply: ...; code fence for code; state_updates as inline JSON
assistant_reply = ""
code_snippet = ""
state_updates = {}
suggested_tags = []
# a) extract code fence (first code block)
code_snippet = extract_code_fence(raw) or ""
# b) extract assistant_reply by looking for lines like "assistant_reply:" or markdown bold
m = re.search(r'assistant_reply\s*[:\-]\s*(["\']?)([\s\S]*?)(?=\n[a-z_]+[\s\-:]{1}|$)', raw, re.IGNORECASE)
if m:
assistant_reply = m.group(2).strip()
else:
# fallback: take everything up to the first code fence or up to "state_updates"
cut_idx = raw.find("```")
state_idx = raw.lower().find("state_updates")
end = min([i for i in (cut_idx if cut_idx>=0 else len(raw), state_idx if state_idx>=0 else len(raw))])
assistant_reply = raw[:end].strip()
# strip any leading labels like "**assistant_reply**:" or similar
assistant_reply = re.sub(r'^\**\s*assistant_reply\**\s*[:\-]?\s*', '', assistant_reply, flags=re.IGNORECASE).strip()
# c) find state_updates JSON block if present
m_state = re.search(r"state_updates\s*[:\-]?\s*(\{[\s\S]*?\})", raw, re.IGNORECASE)
if m_state:
try:
state_updates = json.loads(m_state.group(1))
except Exception:
state_updates = {}
# d) suggested_tags simple extract
m_tags = re.search(r"suggested_tags\s*[:\-]?\s*(\[[^\]]*\])", raw, re.IGNORECASE)
if m_tags:
try:
suggested_tags = json.loads(m_tags.group(1))
except Exception:
suggested_tags = []
result = {
"assistant_reply": assistant_reply or default["assistant_reply"],
"code_snippet": code_snippet or "",
"state_updates": state_updates or {"conversationSummary": "", "language": "python"},
"suggested_tags": suggested_tags or [],
"parse_ok": bool(assistant_reply or code_snippet),
}
return result
# === Flask app ===
BASE_DIR = Path(__file__).resolve().parent
app = Flask(__name__, static_folder=str(BASE_DIR / "static"), static_url_path="/static")
CORS(app)
@app.route("/", methods=["GET"])
def serve_frontend():
try:
return app.send_static_file("frontend.html")
except Exception:
return "<h3>frontend.html not found in static/ — please add your frontend.html there.</h3>", 404
@app.route("/chat", methods=["POST"])
def chat():
payload = request.get_json(force=True, silent=True)
if not isinstance(payload, dict):
return jsonify({"error": "invalid request body"}), 400
chat_history = payload.get("chat_history", [])
assistant_state = payload.get("assistant_state", {})
# validate/normalize assistant_state
state = {
"conversationSummary": assistant_state.get("conversationSummary", "").strip(),
"language": assistant_state.get("language", "python").strip().lower(),
"taggedReplies": assistant_state.get("taggedReplies", []),
}
# limit history length to recent messages to control token usage
if isinstance(chat_history, list) and len(chat_history) > MAX_HISTORY_MESSAGES:
chat_history = chat_history[-MAX_HISTORY_MESSAGES:]
# build messages for LLM (do not mutate user's last message)
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
for m in chat_history:
if not isinstance(m, dict):
continue
role = m.get("role")
content = m.get("content")
if role in ("user", "assistant") and content:
messages.append({"role": role, "content": content})
# append a supplemental context message (do not overwrite)
context_hint = f"[CONTEXT] language={state['language']} summary={clamp_summary(state['conversationSummary'], 300)}"
messages.append({"role": "system", "content": context_hint})
# call LLM (wrap in try/except)
try:
raw_resp = llm.invoke(messages, timeout=LLM_TIMEOUT_SECONDS)
# tolerate different shapes
raw_text = getattr(raw_resp, "content", None) or getattr(raw_resp, "text", None) or str(raw_resp)
logger.info("LLM raw text: %.300s", raw_text.replace('\n', ' ')[:300])
except Exception as e:
logger.exception("LLM invocation error")
return jsonify({"error": "LLM invocation failed", "detail": str(e)}), 500
parsed = parse_llm_output(raw_text)
# If parse failed, don't overwrite the existing state; give helpful message.
if not parsed.get("parse_ok"):
logger.warning("Parse failure. Returning fallback message.")
return jsonify({
"assistant_reply": parsed["assistant_reply"],
"code_snippet": "",
"updated_state": state,
"suggested_tags": [],
"parse_ok": False,
}), 200
# Validate and apply state_updates conservatively
updates = parsed.get("state_updates", {}) or {}
if isinstance(updates, dict):
if "conversationSummary" in updates:
state["conversationSummary"] = clamp_summary(str(updates["conversationSummary"]))
if "language" in updates:
lang = canonicalize_language(str(updates["language"]))
if lang:
state["language"] = lang
# limit suggested tags
tags = parsed.get("suggested_tags", []) or []
if isinstance(tags, list):
tags = [str(t).strip() for t in tags if t and isinstance(t, (str,))]
tags = tags[:3]
return jsonify({
"assistant_reply": parsed.get("assistant_reply", ""),
"code_snippet": parsed.get("code_snippet", ""),
"updated_state": state,
"suggested_tags": tags,
"parse_ok": True,
}), 200
@app.route("/tag_reply", methods=["POST"])
def tag_reply():
data = request.get_json(force=True, silent=True)
if not isinstance(data, dict):
return jsonify({"error": "invalid request body"}), 400
reply_content = data.get("reply")
tags = data.get("tags", [])
if not reply_content or not tags or not isinstance(tags, list):
return jsonify({"error": "Missing 'reply' or 'tags' in request"}), 400
tags_clean = [str(t).strip().lower() for t in tags if re.match(r'^[\w\-]{1,30}$', str(t).strip())]
if not tags_clean:
return jsonify({"error": "No valid tags provided"}), 400
assistant_state = data.get("assistant_state", {})
state = {
"conversationSummary": assistant_state.get("conversationSummary", ""),
"language": assistant_state.get("language", "python"),
"taggedReplies": assistant_state.get("taggedReplies", []),
}
state["taggedReplies"].append({"reply": reply_content, "tags": tags_clean})
logger.info("Tagged reply saved: %s", tags_clean)
return jsonify({"message": "Reply saved", "updated_state": state}), 200
@app.route("/ping", methods=["GET"])
def ping():
return jsonify({"status": "ok"})
if __name__ == "__main__":
port = int(os.getenv("PORT", "7860"))
app.run(host="0.0.0.0", port=port, debug=True)