File size: 2,700 Bytes
4a5d5fa
c5fb75b
4a5d5fa
c5fb75b
 
0050d71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328d050
c5fb75b
4a5d5fa
c5fb75b
4a5d5fa
 
 
 
 
 
 
 
 
c5fb75b
4a5d5fa
c5fb75b
4a5d5fa
 
 
 
 
 
 
c5fb75b
 
 
4a5d5fa
 
 
c5fb75b
4a5d5fa
 
 
c5fb75b
4a5d5fa
c5fb75b
4a5d5fa
c5fb75b
4a5d5fa
c5fb75b
 
4a5d5fa
 
c5fb75b
4a5d5fa
 
c5fb75b
4a5d5fa
 
 
c5fb75b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import logging
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from services import queue_manager
import os
from pathlib import Path

# CACHE PATCH BLOCK: place FIRST in pipeline.py!
# Creates one directory under /tmp and points every cache-, home- and
# temp-related environment variable listed below at it.
# NOTE(review): presumably the default home/cache locations are not writable
# in the target deployment -- confirm.
# NOTE(review): the imports at the top of this file (including
# `from services import queue_manager`) execute before this block; if any of
# them loads a library that reads these variables at import time, it will not
# see the values set here -- confirm the ordering is intended.
HF_CACHE_DIR = Path("/tmp/hf_cache")
HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)
os.environ.update(
    dict.fromkeys(
        (
            "HF_HOME",
            "HF_HUB_CACHE",
            "DIFFUSERS_CACHE",
            "TRANSFORMERS_CACHE",
            "XDG_CACHE_HOME",
            "HF_DATASETS_CACHE",
            "HF_MODULES_CACHE",
            "TMPDIR",
            "CACHE_DIR",
            "TORCH_HOME",
            "HOME",
        ),
        str(HF_CACHE_DIR),
    )
)
import os.path

# Stash the stock implementation exactly once, so executing this block again
# (e.g. on a module reload) never ends up wrapping the wrapper itself.
if not hasattr(os.path, "expanduser_original"):
    os.path.expanduser_original = os.path.expanduser


def safe_expanduser(path):
    """Replacement for ``os.path.expanduser`` that redirects home/cache paths.

    Any ``str`` path (or ``os.PathLike`` that resolves to ``str``) starting
    with ``~``, ``/.cache`` or ``/root/.cache`` is mapped to the shared cache
    directory ``HF_CACHE_DIR``. Everything else, including ``bytes`` paths,
    is passed to the original ``os.path.expanduser``.

    NOTE(review): for a redirected path the remainder after the prefix is
    discarded, so ``~/a/b`` also maps to the cache root itself. That existing
    behaviour is kept unchanged here.

    Args:
        path: ``str``, ``bytes`` or ``os.PathLike`` -- the same inputs the
            standard ``os.path.expanduser`` accepts.

    Returns:
        ``str(HF_CACHE_DIR)`` for redirected paths, otherwise whatever the
        original ``os.path.expanduser`` returns.

    Raises:
        TypeError: if ``path`` is not ``str``, ``bytes`` or ``os.PathLike``
            (raised by ``os.fspath``).
    """
    # os.fspath lets Path objects through; the previous direct
    # ``path.startswith`` call raised AttributeError for them and TypeError
    # for bytes, even though the function being replaced accepts both.
    fs_path = os.fspath(path)
    if isinstance(fs_path, str) and fs_path.startswith(
        ("~", "/.cache", "/root/.cache")
    ):
        return str(HF_CACHE_DIR)
    return os.path.expanduser_original(fs_path)


os.path.expanduser = safe_expanduser


# Root logger: INFO and above, formatted as timestamp, level, message.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

# The ASGI application object; the route handlers below register on it.
app = FastAPI(title="AI ADD Generator", version="1.0")

# CORS: any origin, method and header is allowed. Credentials are disabled
# because the FastAPI CORS documentation states that allow_origins,
# allow_methods and allow_headers must not be ["*"] when
# allow_credentials=True. To support credentialed cross-origin requests,
# replace the wildcards with explicit lists and turn the flag back on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ---------------------------
# Pydantic models
# ---------------------------
class IdeaRequest(BaseModel):
    """Request body accepted by ``POST /submit_idea``."""

    # Idea text; the handler hands it to queue_manager.add_task as-is.
    idea: str

class ConfirmationRequest(BaseModel):
    """Request body accepted by ``POST /confirm``."""

    # Identifier of the task the request refers to; the handler looks it up
    # with queue_manager.get_task_status.
    task_id: str
    # Confirmation decision supplied by the client.
    confirm: bool

# ---------------------------
# API endpoints
# ---------------------------
@app.post("/submit_idea")
async def submit_idea(request: IdeaRequest):
    """Queue a new idea and report the identifier assigned to it.

    Awaits ``queue_manager.add_task`` with the idea text from the request
    body and returns its result under ``task_id``.

    Args:
        request: Parsed request body carrying the ``idea`` string.

    Returns:
        A dict with ``status`` set to ``"submitted"`` and the new ``task_id``.
    """
    idea_text = request.idea
    new_task_id = await queue_manager.add_task(idea_text)
    return dict(status="submitted", task_id=new_task_id)

@app.post("/confirm")
async def confirm_task(request: ConfirmationRequest):
    """Confirm a task that is waiting for confirmation.

    The task is looked up through ``queue_manager.get_task_status`` and must
    currently have status ``TaskStatus.WAITING_CONFIRMATION``. It is confirmed
    only when the request's ``confirm`` flag is true; previously the flag was
    ignored and a request with ``confirm=false`` still confirmed the task.

    Args:
        request: Parsed request body with ``task_id`` and ``confirm``.

    Returns:
        ``{"status": "confirmed", "task": ...}`` after confirming, or
        ``{"status": "not_confirmed", "task": ...}`` when ``confirm`` is
        false. ``task`` is the value fetched before any confirmation ran.

    Raises:
        HTTPException: 404 if the lookup returns a falsy value; 400 if the
            task's status is not ``WAITING_CONFIRMATION``.
    """
    task = queue_manager.get_task_status(request.task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    if task["status"] != queue_manager.TaskStatus.WAITING_CONFIRMATION:
        raise HTTPException(status_code=400, detail="Task not waiting for confirmation")

    if not request.confirm:
        # The client declined: do not trigger confirmation, leave the task as
        # it is.
        # TODO(review): if queue_manager offers a cancel/reject operation,
        # this is probably where it should be called -- confirm.
        return {"status": "not_confirmed", "task": task}

    await queue_manager.confirm_task(request.task_id)
    return {"status": "confirmed", "task": task}

@app.get("/status/{task_id}")
async def status(task_id: str):
    """Return whatever ``queue_manager.get_task_status`` holds for a task.

    Args:
        task_id: Task identifier taken from the URL path.

    Returns:
        The looked-up task value, unchanged.

    Raises:
        HTTPException: 404 when the lookup returns a falsy value.
    """
    found = queue_manager.get_task_status(task_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Task not found")

@app.get("/")
async def health():
    """Return a fixed payload signalling that the service is up."""
    payload = dict(status="running")
    return payload