karthikeya1212 commited on
Commit
c0bbf55
·
verified ·
1 Parent(s): 54c82a8

Update api/server.py

Browse files
Files changed (1) hide show
  1. api/server.py +63 -133
api/server.py CHANGED
@@ -190,9 +190,15 @@
190
 
191
 
192
 
193
-
 
 
 
 
194
  import os
195
  from pathlib import Path
 
 
196
 
197
  # CACHE PATCH BLOCK: place FIRST in pipeline.py!
198
  HF_CACHE_DIR = Path("/tmp/hf_cache")
@@ -210,33 +216,6 @@ os.environ.update({
210
  "TORCH_HOME": str(HF_CACHE_DIR),
211
  "HOME": str(HF_CACHE_DIR)
212
  })
213
-
214
-
215
- from fastapi import FastAPI, BackgroundTasks
216
- from fastapi.middleware.cors import CORSMiddleware
217
-
218
- app = FastAPI(title="AdMaker API")
219
-
220
- # Allow frontend access
221
- app.add_middleware(
222
- CORSMiddleware,
223
- allow_origins=["*"], # or restrict to your domain
224
- allow_credentials=True,
225
- allow_methods=["*"],
226
- allow_headers=["*"],
227
- )
228
-
229
- # Example routes
230
- @app.get("/")
231
- def root():
232
- return {"message": "AdMaker backend running!"}
233
-
234
- @app.get("/health")
235
- def health_check():
236
- return {"status": "ok"}
237
-
238
- # (Optional) Add your /submit_idea, /status/{task_id}, etc. here
239
-
240
  import os.path
241
  if not hasattr(os.path, "expanduser_original"):
242
  os.path.expanduser_original = os.path.expanduser
@@ -250,110 +229,61 @@ def safe_expanduser(path):
250
  return os.path.expanduser_original(path)
251
  os.path.expanduser = safe_expanduser
252
 
253
- import asyncio
254
- import logging
255
- import core.script_gen as script_gen
256
- import core.story_script as story_script
257
- # IMAGE: enabled by default (uncommented) — pipeline will run image generation and return images
258
- import core.image_generator as image_gen
259
- # VIDEO / MUSIC / ASSEMBLER placeholders: uncomment to enable those stages
260
- # import core.video_gen as video_gen
261
- # import core.music_gen as music_gen
262
- # import core.assemble as assemble
263
-
264
- logging.basicConfig(
265
- level=logging.INFO,
266
- format="%(asctime)s [%(levelname)s] %(message)s"
267
  )
268
 
269
- async def run_pipeline(task: dict, confirmation_event: asyncio.Event):
270
- """
271
- Executes the workflow and updates task['result'] after each stage so frontend
272
- can poll /status/{task_id} to get intermediate outputs.
273
-
274
- Behavior:
275
- - generate_script -> write task.result['script'] and set status to waiting_for_confirmation
276
- Frontend can edit and confirm the script.
277
- - on confirm -> generate story, write task.result['story_script']
278
- - generate images (image_gen) -> write task.result['images']
279
- - optional stages (video_gen, music_gen, assemble) are executed only if their modules are imported.
280
- Each stage updates task['result'] and task['status'] so frontend can receive intermediate outputs.
281
- """
282
- task_id = task["id"]
283
- idea = task["idea"]
284
-
285
- logging.info(f"[Pipeline] Starting script generation for task {task_id}")
286
- # 1) Script generation — update result so frontend can read it immediately
287
- script = await script_gen.generate_script(idea)
288
- task["result"]["script"] = script
289
- task["status"] = "waiting_for_confirmation"
290
- task["confirmation_required"] = True
291
- logging.info(f"[Pipeline] Script ready for task {task_id}; waiting for confirmation/edit...")
292
-
293
- # Wait for frontend confirmation (may include edited script saved into task["result"]["script"])
294
- await confirmation_event.wait()
295
-
296
- # confirmed, proceed
297
- task["status"] = "confirmed"
298
- task["confirmation_required"] = False
299
- logging.info(f"[Pipeline] Task {task_id} confirmed. Generating story based on confirmed script...")
300
-
301
- # 2) Story generation — use possibly edited/confirmed script
302
- confirmed_script = task["result"].get("script", script)
303
- story = await story_script.generate_story(confirmed_script)
304
- task["result"]["story_script"] = story
305
- task["status"] = "story_generated"
306
- logging.info(f"[Pipeline] Story ready for task {task_id}")
307
-
308
- # 3) Image generation (enabled)
309
- try:
310
- logging.info(f"[Pipeline] Generating images for task {task_id}")
311
- images = await image_gen.generate_images(story)
312
- task["result"]["images"] = images
313
- task["status"] = "images_generated"
314
- logging.info(f"[Pipeline] Images ready for task {task_id}")
315
- except Exception as e:
316
- # keep pipeline going; store error for this stage
317
- task.setdefault("stage_errors", {})["images"] = str(e)
318
- logging.exception(f"[Pipeline] Image generation failed for task {task_id}: {e}")
319
-
320
- # 4) Optional: video generation (uncomment import to enable)
321
- # if "video_gen" in globals():
322
- # try:
323
- # logging.info(f"[Pipeline] Generating video for task {task_id}")
324
- # video = await video_gen.generate_video(task["result"].get("images"))
325
- # task["result"]["video"] = video
326
- # task["status"] = "video_generated"
327
- # logging.info(f"[Pipeline] Video ready for task {task_id}")
328
- # except Exception as e:
329
- # task.setdefault("stage_errors", {})["video"] = str(e)
330
- # logging.exception(f"[Pipeline] Video generation failed for task {task_id}: {e}")
331
-
332
- # 5) Optional: music generation (uncomment import to enable)
333
- # if "music_gen" in globals():
334
- # try:
335
- # logging.info(f"[Pipeline] Generating music for task {task_id}")
336
- # music = await music_gen.generate_music(story)
337
- # task["result"]["music"] = music
338
- # task["status"] = "music_generated"
339
- # logging.info(f"[Pipeline] Music ready for task {task_id}")
340
- # except Exception as e:
341
- # task.setdefault("stage_errors", {})["music"] = str(e)
342
- # logging.exception(f"[Pipeline] Music generation failed for task {task_id}: {e}")
343
-
344
- # 6) Optional: assembler (uncomment import to enable)
345
- # if "assemble" in globals():
346
- # try:
347
- # logging.info(f"[Pipeline] Assembling final output for task {task_id}")
348
- # final_output = await assemble.create_final(task["result"])
349
- # task["result"]["final_output"] = final_output
350
- # task["status"] = "assembled"
351
- # logging.info(f"[Pipeline] Assembly complete for task {task_id}")
352
- # except Exception as e:
353
- # task.setdefault("stage_errors", {})["assemble"] = str(e)
354
- # logging.exception(f"[Pipeline] Assembly failed for task {task_id}: {e}")
355
-
356
- # Finalize
357
- task["status"] = "completed"
358
- logging.info(f"[Pipeline] Task {task_id} completed.")
359
- return task["result"]
 
190
 
191
 
192
 
193
+ import logging
194
+ from fastapi import FastAPI, HTTPException
195
+ from fastapi.middleware.cors import CORSMiddleware
196
+ from pydantic import BaseModel
197
+ from services import queue_manager
198
  import os
199
  from pathlib import Path
200
+ from typing import Optional
201
+
202
 
203
  # CACHE PATCH BLOCK: place FIRST in pipeline.py!
204
  HF_CACHE_DIR = Path("/tmp/hf_cache")
 
216
  "TORCH_HOME": str(HF_CACHE_DIR),
217
  "HOME": str(HF_CACHE_DIR)
218
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  import os.path
220
  if not hasattr(os.path, "expanduser_original"):
221
  os.path.expanduser_original = os.path.expanduser
 
229
  return os.path.expanduser_original(path)
230
  os.path.expanduser = safe_expanduser
231
 
232
+
233
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
234
+
235
+ app = FastAPI(title="AI ADD Generator", version="1.0")
236
+
237
+ app.add_middleware(
238
+ CORSMiddleware,
239
+ allow_origins=["*"],
240
+ allow_credentials=True,
241
+ allow_methods=["*"],
242
+ allow_headers=["*"],
 
 
 
243
  )
244
 
245
+ # ---------------------------
246
+ # Pydantic models
247
+ # ---------------------------
248
+ class IdeaRequest(BaseModel):
249
+ idea: str
250
+
251
+ class ConfirmationRequest(BaseModel):
252
+ task_id: str
253
+ confirm: bool
254
+ edited_script: Optional[str] = None
255
+
256
+ # ---------------------------
257
+ # API endpoints
258
+ # ---------------------------
259
+ @app.post("/submit_idea")
260
+ async def submit_idea(request: IdeaRequest):
261
+ task_id = await queue_manager.add_task(request.idea)
262
+ return {"status": "submitted", "task_id": task_id}
263
+
264
+ @app.post("/confirm")
265
+ async def confirm_task(request: ConfirmationRequest):
266
+ task = queue_manager.get_task_status(request.task_id)
267
+ if not task:
268
+ raise HTTPException(status_code=404, detail="Task not found")
269
+ # status values are stored as strings by queue_manager/pipeline
270
+ if task["status"] != queue_manager.TaskStatus.WAITING_CONFIRMATION.value:
271
+ raise HTTPException(status_code=400, detail="Task not waiting for confirmation")
272
+
273
+ # if frontend supplied an edited script, persist it before unblocking the pipeline
274
+ if request.edited_script:
275
+ task["result"]["script"] = request.edited_script
276
+
277
+ await queue_manager.confirm_task(request.task_id)
278
+ return {"status": "confirmed", "task": task}
279
+
280
+ @app.get("/status/{task_id}")
281
+ async def status(task_id: str):
282
+ task = queue_manager.get_task_status(task_id)
283
+ if not task:
284
+ raise HTTPException(status_code=404, detail="Task not found")
285
+ return task
286
+
287
+ @app.get("/")
288
+ async def health():
289
+ return {"status": "running"}