from __future__ import annotations

import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import InferenceClient
from pydantic import BaseModel
app = FastAPI(
    title="Sheikh LLM Studio",
    description="Advanced LLM platform with chat, tools, and model workflows",
    version="2.0.0",
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

STATIC_DIR = "static"
TEMPLATES_DIR = "templates"

app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
templates = Jinja2Templates(directory=TEMPLATES_DIR)
class Config:
    HF_TOKEN: Optional[str] = os.getenv("HF_TOKEN")
    AVAILABLE_MODELS: Dict[str, str] = {
        "mistral-small": "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
        "mistral-large": "mistralai/Mistral-Large-Instruct-2411",
        "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
        "baby-grok": "IntelligentEstate/Baby_Grok3-1.5b-iQ4_K_M-GGUF",
    }
class ChatRequest(BaseModel):
    message: str
    model: str = "mistral-small"
    max_tokens: int = 500
    temperature: float = 0.7
    stream: bool = False


class ChatResponse(BaseModel):
    response: str
    model: str
    status: str


class ToolRequest(BaseModel):
    tool: str
    parameters: Dict[str, Any]


class ModelConfig(BaseModel):
    base_model: str
    dataset_path: str
    training_config: Dict[str, Any]
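# Illustrative request bodies for the endpoints below. Field names come from
# the Pydantic models above; the values are made up for the example:
#   ChatRequest  -> {"message": "Hello", "model": "mistral-7b",
#                    "max_tokens": 200, "temperature": 0.7, "stream": false}
#   ToolRequest  -> {"tool": "web_search", "parameters": {"query": "fastapi websockets"}}
#   ModelConfig  -> {"base_model": "mistral-7b", "dataset_path": "data/train.jsonl",
#                    "training_config": {"epochs": 3}}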
connected_clients: List[WebSocket] = []
# NOTE: the route decorators below were missing from the source; the paths are
# reasonable assumptions and may differ from the original deployment.
@app.on_event("startup")
async def startup_event() -> None:
    logger.info("Starting Sheikh LLM Studio")
    if not Config.HF_TOKEN:
        logger.warning("HF_TOKEN not set; gated models will not be accessible.")
@app.get("/", response_class=HTMLResponse)
async def home(request: Request) -> HTMLResponse:
    return templates.TemplateResponse("index.html", {"request": request})


@app.get("/chat", response_class=HTMLResponse)
async def chat_interface(request: Request) -> HTMLResponse:
    return templates.TemplateResponse("chat.html", {"request": request})


@app.get("/studio", response_class=HTMLResponse)
async def model_studio(request: Request) -> HTMLResponse:
    return templates.TemplateResponse("studio.html", {"request": request})


@app.get("/api/models")
async def get_available_models() -> Dict[str, Any]:
    return {"models": Config.AVAILABLE_MODELS, "status": "success"}
@app.post("/api/chat", response_model=ChatResponse)
async def chat_completion(request: ChatRequest) -> ChatResponse:
    if not request.message.strip():
        raise HTTPException(status_code=400, detail="Message cannot be empty")
    if request.model not in Config.AVAILABLE_MODELS:
        raise HTTPException(status_code=400, detail="Unknown model selection")
    if Config.HF_TOKEN is None:
        raise HTTPException(status_code=500, detail="HF_TOKEN environment variable is not set")

    model_id = Config.AVAILABLE_MODELS[request.model]
    client = InferenceClient(model=model_id, token=Config.HF_TOKEN)

    prompt = request.message
    if "mistral" in request.model:
        # Wrap the message in the Mistral instruction template.
        prompt = f"<s>[INST] {request.message.strip()} [/INST]"

    try:
        if request.stream:
            # With stream=True (and the default details=False) the client
            # yields plain text chunks, so they can be concatenated directly.
            generated_text = ""
            for chunk in client.text_generation(
                prompt,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                stream=True,
            ):
                generated_text += chunk
                await asyncio.sleep(0)  # yield control back to the event loop
        else:
            # text_generation is a blocking call; run it in a worker thread so
            # it does not stall the event loop.
            generated_text = await asyncio.to_thread(
                client.text_generation,
                prompt,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
            )
    except Exception as exc:  # pragma: no cover - external service
        logger.error("Chat generation failed: %s", exc)
        raise HTTPException(status_code=502, detail=f"Model error: {exc}") from exc

    return ChatResponse(response=generated_text, model=request.model, status="success")
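# Example call against the endpoint above (assuming the assumed /api/chat
# path; run from a shell while the server is up):
#   curl -X POST http://localhost:7860/api/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "Explain transformers in one sentence.", "model": "mistral-7b"}'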
@app.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket) -> None:
    await websocket.accept()
    connected_clients.append(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            message_data = json.loads(data)
            user_message = message_data.get("message", "")
            # Placeholder behaviour: echo the message back one character at a
            # time to simulate token streaming.
            response_text = f"Echo: {user_message}"
            for index in range(1, len(response_text) + 1):
                await websocket.send_text(json.dumps({"chunk": response_text[:index], "done": False}))
                await asyncio.sleep(0.1)
            await websocket.send_text(json.dumps({"chunk": response_text, "done": True}))
    except WebSocketDisconnect:
        connected_clients.remove(websocket)
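# Minimal client sketch for the websocket endpoint above, assuming the assumed
# /ws/chat path and the third-party `websockets` package (not a dependency of
# this app):
#
#   import asyncio, json, websockets
#
#   async def demo() -> None:
#       async with websockets.connect("ws://localhost:7860/ws/chat") as ws:
#           await ws.send(json.dumps({"message": "hello"}))
#           while True:
#               event = json.loads(await ws.recv())
#               print(event["chunk"])
#               if event["done"]:
#                   break
#
#   asyncio.run(demo())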
@app.post("/api/tools/search")
async def search_tool(request: ToolRequest) -> Dict[str, Any]:
    if request.tool != "web_search":
        raise HTTPException(status_code=400, detail="Unknown tool")
    query = request.parameters.get("query", "")
    # Placeholder results; no real search backend is wired up yet.
    return {
        "tool": "web_search",
        "results": [
            {"title": f"Result 1 for {query}", "url": "#"},
            {"title": f"Result 2 for {query}", "url": "#"},
        ],
        "status": "success",
    }
@app.post("/api/tools/code")
async def code_tool(request: ToolRequest) -> Dict[str, Any]:
    if request.tool != "execute_python":
        raise HTTPException(status_code=400, detail="Unknown tool")
    code = request.parameters.get("code", "")
    # Placeholder: the code is echoed back, not actually executed.
    return {
        "tool": "execute_python",
        "output": f"Executed code: {code}",
        "status": "success",
    }
@app.post("/api/studio/train")
async def create_model(config: ModelConfig) -> Dict[str, Any]:
    # Placeholder: records the job metadata but does not launch real training.
    job_id = f"train_{int(time.time())}"
    training_job = {
        "job_id": job_id,
        "status": "queued",
        "base_model": config.base_model,
        "dataset_path": config.dataset_path,
        "config": config.training_config,
    }
    return {
        "job": training_job,
        "message": "Training job queued successfully",
        "status": "success",
    }
@app.get("/api/studio/jobs/{job_id}")
async def get_training_job(job_id: str) -> Dict[str, Any]:
    # Placeholder status: every job is reported as completed.
    return {
        "job_id": job_id,
        "status": "completed",
        "progress": 100,
        "model_url": f"https://huggingface.co/RecentCoders/{job_id}",
    }
@app.get("/health")
async def health_check() -> Dict[str, Any]:
    return {
        "status": "healthy",
        "service": "sheikh-llm-studio",
        "version": "2.0.0",
        "features": ["chat", "tools", "model_studio", "websockets"],
    }
if __name__ == "__main__":  # pragma: no cover
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
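# To run locally (assuming this module is saved as app.py):
#   python app.py            # or: uvicorn app:app --host 0.0.0.0 --port 7860
# Smoke test:
#   curl http://localhost:7860/health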