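"""Sheikh LLM Studio: FastAPI app exposing chat, tool, and model-workflow endpoints."""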
from __future__ import annotations
import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import InferenceClient
from pydantic import BaseModel
app = FastAPI(
title="Sheikh LLM Studio",
description="Advanced LLM platform with chat, tools, and model workflows",
version="2.0.0",
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
STATIC_DIR = "static"
TEMPLATES_DIR = "templates"
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
templates = Jinja2Templates(directory=TEMPLATES_DIR)
class Config:
    """Runtime configuration read from the environment."""
    HF_TOKEN: Optional[str] = os.getenv("HF_TOKEN")
    # Friendly aliases mapped to Hugging Face Hub model repo IDs.
    AVAILABLE_MODELS: Dict[str, str] = {
"mistral-small": "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
"mistral-large": "mistralai/Mistral-Large-Instruct-2411",
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
"baby-grok": "IntelligentEstate/Baby_Grok3-1.5b-iQ4_K_M-GGUF",
}
class ChatRequest(BaseModel):
message: str
model: str = "mistral-small"
max_tokens: int = 500
temperature: float = 0.7
stream: bool = False
class ChatResponse(BaseModel):
response: str
model: str
status: str
class ToolRequest(BaseModel):
tool: str
parameters: Dict[str, Any]
class ModelConfig(BaseModel):
base_model: str
dataset_path: str
training_config: Dict[str, Any]
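# Active WebSocket connections on the chat channel.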
connected_clients: List[WebSocket] = []
@app.on_event("startup")
async def startup_event() -> None:
logger.info("Starting Sheikh LLM Studio")
if not Config.HF_TOKEN:
logger.warning("HF_TOKEN not set; gated models will not be accessible.")
@app.get("/", response_class=HTMLResponse)
async def home(request: Request) -> HTMLResponse:
return templates.TemplateResponse("index.html", {"request": request})
@app.get("/chat", response_class=HTMLResponse)
async def chat_interface(request: Request) -> HTMLResponse:
return templates.TemplateResponse("chat.html", {"request": request})
@app.get("/studio", response_class=HTMLResponse)
async def model_studio(request: Request) -> HTMLResponse:
return templates.TemplateResponse("studio.html", {"request": request})
@app.get("/api/models")
async def get_available_models() -> Dict[str, Any]:
return {"models": Config.AVAILABLE_MODELS, "status": "success"}
@app.post("/api/chat", response_model=ChatResponse)
async def chat_completion(request: ChatRequest) -> ChatResponse:
if not request.message.strip():
raise HTTPException(status_code=400, detail="Message cannot be empty")
if request.model not in Config.AVAILABLE_MODELS:
raise HTTPException(status_code=400, detail="Unknown model selection")
if Config.HF_TOKEN is None:
raise HTTPException(status_code=500, detail="HF_TOKEN environment variable is not set")
model_id = Config.AVAILABLE_MODELS[request.model]
client = InferenceClient(model=model_id, token=Config.HF_TOKEN)
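    # Mistral instruct checkpoints expect the [INST] chat-template wrapper.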
prompt = request.message
if "mistral" in request.model:
prompt = f"<s>[INST] {request.message.strip()} [/INST]"
try:
        if request.stream:
            # With stream=True (and details left as False) the client yields
            # plain string tokens, so chunks are concatenated directly. Note
            # that the HTTP response itself is still returned in one piece.
            generated_text = ""
            for chunk in client.text_generation(
                prompt,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                stream=True,
            ):
                generated_text += chunk
                await asyncio.sleep(0)  # yield control to the event loop
        else:
            # text_generation is a blocking HTTP call; run it off the event loop.
            generated_text = await asyncio.to_thread(
                client.text_generation,
                prompt,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
            )
except Exception as exc: # pragma: no cover - external service
logger.error("Chat generation failed: %s", exc)
raise HTTPException(status_code=502, detail=f"Model error: {exc}") from exc
return ChatResponse(response=generated_text, model=request.model, status="success")
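# Example request against this endpoint (illustrative; host and port assume the
# uvicorn settings at the bottom of this file):
#   curl -X POST http://localhost:7860/api/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "Hello", "model": "mistral-small"}'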
@app.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket) -> None:
await websocket.accept()
connected_clients.append(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            message_data = json.loads(data)
            user_message = message_data.get("message", "")
            # Demo behaviour: stream the echo back one character at a time.
            response_text = f"Echo: {user_message}"
            for index in range(1, len(response_text) + 1):
                await websocket.send_text(json.dumps({"chunk": response_text[:index], "done": False}))
                await asyncio.sleep(0.1)
            await websocket.send_text(json.dumps({"chunk": response_text, "done": True}))
    except WebSocketDisconnect:
        pass
    finally:
        # Drop the socket even if an unexpected error ends the loop.
        if websocket in connected_clients:
            connected_clients.remove(websocket)
@app.post("/api/tools/search")
async def search_tool(request: ToolRequest) -> Dict[str, Any]:
if request.tool != "web_search":
raise HTTPException(status_code=400, detail="Unknown tool")
query = request.parameters.get("query", "")
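    # Placeholder: returns canned results; no real search backend is queried.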
return {
"tool": "web_search",
"results": [
{"title": f"Result 1 for {query}", "url": "#"},
{"title": f"Result 2 for {query}", "url": "#"},
],
"status": "success",
}
@app.post("/api/tools/code")
async def code_tool(request: ToolRequest) -> Dict[str, Any]:
if request.tool != "execute_python":
raise HTTPException(status_code=400, detail="Unknown tool")
code = request.parameters.get("code", "")
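    # Placeholder: the code is echoed back, not executed (no sandbox here).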
return {
"tool": "execute_python",
"output": f"Executed code: {code}",
"status": "success",
}
@app.post("/api/studio/create-model")
async def create_model(config: ModelConfig) -> Dict[str, Any]:
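    # Stub: the job is acknowledged and described, but no training is scheduled.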
job_id = f"train_{int(time.time())}"
training_job = {
"job_id": job_id,
"status": "queued",
"base_model": config.base_model,
"dataset_path": config.dataset_path,
"config": config.training_config,
}
return {
"job": training_job,
"message": "Training job queued successfully",
"status": "success",
}
@app.get("/api/studio/jobs/{job_id}")
async def get_training_job(job_id: str) -> Dict[str, Any]:
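    # Stub: every job is reported as completed; no job store is consulted.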
return {
"job_id": job_id,
"status": "completed",
"progress": 100,
"model_url": f"https://huggingface.co/RecentCoders/{job_id}",
}
@app.get("/health")
async def health_check() -> Dict[str, Any]:
return {
"status": "healthy",
"service": "sheikh-llm-studio",
"version": "2.0.0",
"features": ["chat", "tools", "model_studio", "websockets"],
}
if __name__ == "__main__": # pragma: no cover
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)