admaker / api /server.py
karthikeya1212's picture
Update api/server.py
0050d71 verified
raw
history blame
2.7 kB
import logging
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from services import queue_manager
import os
from pathlib import Path
# CACHE PATCH BLOCK: redirect model/cache/home/temp directories to a writable
# location under /tmp before libraries resolve their cache paths.
# NOTE(review): `from services import queue_manager` is imported above this
# block, so anything that module imports at load time (presumably the ML
# libraries, if it does) has already read its cache settings before these
# variables are set. Move that import below this block if the redirect must
# apply to it as well.
HF_CACHE_DIR = Path("/tmp/hf_cache")
HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Every variable below is pointed at HF_CACHE_DIR, including HOME and TMPDIR.
_REDIRECTED_ENV_VARS = (
    "HF_HOME",
    "HF_HUB_CACHE",
    "DIFFUSERS_CACHE",
    "TRANSFORMERS_CACHE",
    "XDG_CACHE_HOME",
    "HF_DATASETS_CACHE",
    "HF_MODULES_CACHE",
    "TMPDIR",
    "CACHE_DIR",
    "TORCH_HOME",
    "HOME",
)
os.environ.update(dict.fromkeys(_REDIRECTED_ENV_VARS, str(HF_CACHE_DIR)))

import os.path

# Capture the genuine implementation only once, so that re-executing this
# module (e.g. on reload) never wraps the wrapper.
if not hasattr(os.path, "expanduser_original"):
    os.path.expanduser_original = os.path.expanduser

# Path prefixes re-rooted at HF_CACHE_DIR. "~" is the home directory; the other
# two are presumably the cache locations libraries compute when HOME is "/" or
# "/root" (read-only on the deployment target) — confirm if extending.
_REDIRECTED_PREFIXES = ("~", "/root/.cache", "/.cache")


def safe_expanduser(path):
    """Replace ``os.path.expanduser`` so home/cache paths land in HF_CACHE_DIR.

    A path equal to one of ``"~"``, ``"/root/.cache"`` or ``"/.cache"``, or
    lying beneath one of them, is re-rooted at ``HF_CACHE_DIR``. The remainder
    of the path is preserved: ``"~/.netrc"`` becomes ``"<HF_CACHE_DIR>/.netrc"``
    and ``"/root/.cache/hub"`` becomes ``"<HF_CACHE_DIR>/hub"``. Only whole path
    components match, so ``"/root/.cachex"`` is left alone. Everything else is
    delegated to the original ``os.path.expanduser``, including ``"~user"``
    forms, ``bytes`` paths and non-matching paths.

    Args:
        path: A ``str``, ``bytes`` or ``os.PathLike`` path, as accepted by
            ``os.path.expanduser``.

    Returns:
        The redirected path as ``str``, or the original ``expanduser`` result.
    """
    path = os.fspath(path)  # accept PathLike objects like the real expanduser
    if isinstance(path, str):
        for prefix in _REDIRECTED_PREFIXES:
            if path == prefix or path.startswith(prefix + "/"):
                remainder = path[len(prefix):].lstrip("/")
                if not remainder:
                    return str(HF_CACHE_DIR)
                return os.path.join(str(HF_CACHE_DIR), remainder)
    return os.path.expanduser_original(path)


os.path.expanduser = safe_expanduser
# Root logger: INFO and above, timestamped single-line records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)

# The ASGI application exposed by this module.
app = FastAPI(title="AI ADD Generator", version="1.0")

# Wide-open CORS: any origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True are
# either refused by browsers for credentialed requests or, depending on the
# Starlette version, handled by echoing the caller's Origin back — confirm
# which behaviour is intended and restrict the origin list if cookies/auth
# are ever used.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# ---------------------------
# Pydantic models
# ---------------------------
class IdeaRequest(BaseModel):
    """Request body for ``POST /submit_idea``."""

    # Free-form idea text, forwarded as-is to ``queue_manager.add_task``.
    idea: str
class ConfirmationRequest(BaseModel):
    """Request body for ``POST /confirm``."""

    # Identifier of the queued task, as returned by ``POST /submit_idea``.
    task_id: str
    # Confirmation flag sent by the client.
    confirm: bool
# ---------------------------
# API endpoints
# ---------------------------
@app.post("/submit_idea")
async def submit_idea(request: IdeaRequest):
    """Enqueue ``request.idea`` with the queue manager.

    Returns:
        ``{"status": "submitted", "task_id": <id>}`` where the id is whatever
        ``queue_manager.add_task`` produced.
    """
    new_task_id = await queue_manager.add_task(request.idea)
    response = {"status": "submitted"}
    response["task_id"] = new_task_id
    return response
@app.post("/confirm")
async def confirm_task(request: ConfirmationRequest):
    """Confirm a task that is paused in the WAITING_CONFIRMATION state.

    The task is advanced via ``queue_manager.confirm_task`` only when
    ``request.confirm`` is true. A request with ``confirm=False`` leaves the
    task untouched and reports ``"not_confirmed"``.

    Returns:
        ``{"status": "confirmed" | "not_confirmed", "task": <task>}``. After a
        successful confirmation the task is re-read so it reflects the state
        the queue manager reports at that point.

    Raises:
        HTTPException: 404 if the task is unknown, 400 if it is not waiting
            for confirmation.
    """
    task = queue_manager.get_task_status(request.task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    if task["status"] != queue_manager.TaskStatus.WAITING_CONFIRMATION:
        raise HTTPException(status_code=400, detail="Task not waiting for confirmation")

    if not request.confirm:
        # The client explicitly declined. Never advance the task on a False flag.
        # NOTE(review): no reject/cancel operation is visible from here, so the
        # task stays in WAITING_CONFIRMATION — wire one up if needed.
        return {"status": "not_confirmed", "task": task}

    await queue_manager.confirm_task(request.task_id)
    # Re-read in case get_task_status returns a snapshot rather than a live
    # object; fall back to the earlier view if the lookup now comes back empty.
    updated = queue_manager.get_task_status(request.task_id)
    return {"status": "confirmed", "task": updated or task}
@app.get("/status/{task_id}")
async def status(task_id: str):
    """Return the task record the queue manager holds for ``task_id``.

    Raises:
        HTTPException: 404 when the queue manager returns nothing (or a falsy
            record) for the id.
    """
    found = queue_manager.get_task_status(task_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Task not found")
@app.get("/")
async def health():
    """Health check on the root path; always reports a fixed running status."""
    return dict(status="running")