"""Sheikh LLM Studio: FastAPI service exposing chat, tool stubs, and model-studio endpoints."""

from __future__ import annotations

import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import InferenceClient
from pydantic import BaseModel

app = FastAPI(
    title="Sheikh LLM Studio",
    description="Advanced LLM platform with chat, tools, and model workflows",
    version="2.0.0",
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

STATIC_DIR = "static"
TEMPLATES_DIR = "templates"

app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
templates = Jinja2Templates(directory=TEMPLATES_DIR)


class Config:
    HF_TOKEN: Optional[str] = os.getenv("HF_TOKEN")
    AVAILABLE_MODELS: Dict[str, str] = {
        "mistral-small": "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
        "mistral-large": "mistralai/Mistral-Large-Instruct-2411",
        "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
        "baby-grok": "IntelligentEstate/Baby_Grok3-1.5b-iQ4_K_M-GGUF",
    }


class ChatRequest(BaseModel):
    message: str
    model: str = "mistral-small"
    max_tokens: int = 500
    temperature: float = 0.7
    stream: bool = False


class ChatResponse(BaseModel):
    response: str
    model: str
    status: str


class ToolRequest(BaseModel):
    tool: str
    parameters: Dict[str, Any]


class ModelConfig(BaseModel):
    base_model: str
    dataset_path: str
    training_config: Dict[str, Any]


connected_clients: List[WebSocket] = []


@app.on_event("startup")
async def startup_event() -> None:
    logger.info("Starting Sheikh LLM Studio")
    if not Config.HF_TOKEN:
        logger.warning("HF_TOKEN not set; gated models will not be accessible.")


@app.get("/", response_class=HTMLResponse)
async def home(request: Request) -> HTMLResponse:
    return templates.TemplateResponse("index.html", {"request": request})


@app.get("/chat", response_class=HTMLResponse)
async def chat_interface(request: Request) -> HTMLResponse:
    return templates.TemplateResponse("chat.html", {"request": request})


@app.get("/studio", response_class=HTMLResponse)
async def model_studio(request: Request) -> HTMLResponse:
    return templates.TemplateResponse("studio.html", {"request": request})


@app.get("/api/models")
async def get_available_models() -> Dict[str, Any]:
    return {"models": Config.AVAILABLE_MODELS, "status": "success"}


@app.post("/api/chat", response_model=ChatResponse)
async def chat_completion(request: ChatRequest) -> ChatResponse:
    if not request.message.strip():
        raise HTTPException(status_code=400, detail="Message cannot be empty")
    if request.model not in Config.AVAILABLE_MODELS:
        raise HTTPException(status_code=400, detail="Unknown model selection")
    if Config.HF_TOKEN is None:
        raise HTTPException(status_code=500, detail="HF_TOKEN environment variable is not set")

    model_id = Config.AVAILABLE_MODELS[request.model]
    client = InferenceClient(model=model_id, token=Config.HF_TOKEN)

    prompt = request.message
    if "mistral" in request.model:
        # Mistral instruct models expect the [INST] ... [/INST] chat template.
        prompt = f"[INST] {request.message.strip()} [/INST]"

    try:
        if request.stream:
            # With stream=True and the default details=False, text_generation
            # yields plain string tokens, so we concatenate the chunks directly.
            generated_text = ""
            for chunk in client.text_generation(
                prompt,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                stream=True,
            ):
                generated_text += chunk
                # Yield control so other coroutines are not starved while we
                # consume the synchronous iterator.
                await asyncio.sleep(0)
        else:
            # Note: this call is synchronous and blocks the event loop for
            # the duration of the generation.
            generated_text = client.text_generation(
                prompt,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
            )
    except Exception as exc:  # pragma: no cover - external service
        logger.error("Chat generation failed: %s", exc)
        raise HTTPException(status_code=502, detail=f"Model error: {exc}") from exc

    return ChatResponse(response=generated_text, model=request.model, status="success")


@app.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket) -> None:
    await websocket.accept()
    connected_clients.append(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            message_data = json.loads(data)
            user_message = message_data.get("message", "")
            # Placeholder streaming: echo the message back one character at a
            # time, sending the cumulative text with each frame.
            response_text = f"Echo: {user_message}"
            for index in range(1, len(response_text) + 1):
                await websocket.send_text(
                    json.dumps({"chunk": response_text[:index], "done": False})
                )
                await asyncio.sleep(0.1)
            await websocket.send_text(json.dumps({"chunk": response_text, "done": True}))
    except WebSocketDisconnect:
        pass
    finally:
        # Remove the client on any exit path, not just a clean disconnect,
        # so connected_clients cannot leak stale sockets.
        if websocket in connected_clients:
            connected_clients.remove(websocket)


@app.post("/api/tools/search")
async def search_tool(request: ToolRequest) -> Dict[str, Any]:
    if request.tool != "web_search":
        raise HTTPException(status_code=400, detail="Unknown tool")
    query = request.parameters.get("query", "")
    # Stubbed results; a real integration would call a search backend here.
    return {
        "tool": "web_search",
        "results": [
            {"title": f"Result 1 for {query}", "url": "#"},
            {"title": f"Result 2 for {query}", "url": "#"},
        ],
        "status": "success",
    }


@app.post("/api/tools/code")
async def code_tool(request: ToolRequest) -> Dict[str, Any]:
    if request.tool != "execute_python":
        raise HTTPException(status_code=400, detail="Unknown tool")
    code = request.parameters.get("code", "")
    # Stub: the code is echoed back, never executed, so this endpoint is safe
    # to expose while the sandbox is unimplemented.
    return {
        "tool": "execute_python",
        "output": f"Executed code: {code}",
        "status": "success",
    }


@app.post("/api/studio/create-model")
async def create_model(config: ModelConfig) -> Dict[str, Any]:
    job_id = f"train_{int(time.time())}"
    training_job = {
        "job_id": job_id,
        "status": "queued",
        "base_model": config.base_model,
        "dataset_path": config.dataset_path,
        "config": config.training_config,
    }
    return {
        "job": training_job,
        "message": "Training job queued successfully",
        "status": "success",
    }


@app.get("/api/studio/jobs/{job_id}")
async def get_training_job(job_id: str) -> Dict[str, Any]:
    return {
        "job_id": job_id,
        "status": "completed",
        "progress": 100,
        "model_url": f"https://huggingface.co/RecentCoders/{job_id}",
    }


@app.get("/health")
async def health_check() -> Dict[str, Any]:
    return {
        "status": "healthy",
        "service": "sheikh-llm-studio",
        "version": "2.0.0",
        "features": ["chat", "tools", "model_studio", "websockets"],
    }


if __name__ == "__main__":  # pragma: no cover
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
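
# ---------------------------------------------------------------------------
# Example client usage (a minimal sketch, not part of the service): how a
# caller might exercise the /api/chat endpoint. Assumes the local uvicorn
# runner above and the `httpx` package; neither is required by the app itself.
#
#   import httpx
#
#   resp = httpx.post(
#       "http://localhost:7860/api/chat",
#       json={"message": "Hello!", "model": "mistral-small"},
#       timeout=60,
#   )
#   resp.raise_for_status()
#   print(resp.json()["response"])
# ---------------------------------------------------------------------------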