Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse, JSONResponse, StreamingResponse | |
| from pydantic import BaseModel | |
| import os | |
| import tempfile | |
| import uvicorn | |
| from typing import List, Optional | |
| import logging | |
| from contextlib import asynccontextmanager | |
| # Import your existing RAG system | |
| from rag import RAG | |
| from vector_store import VectorStore | |
# Configure logging: the root logger is set to INFO so the module logger below
# emits INFO-level messages (startup, per-request info, errors).
logging.basicConfig(level=logging.INFO)
# Module-level logger used by every handler in this file.
logger = logging.getLogger(__name__)
# Pydantic models
# Request body carrying a single free-text question for the RAG system.
class QuestionRequest(BaseModel):
    question: str  # passed unchanged to rag_system.ask_question_stream
# Answer payload with optional source references.
# NOTE(review): in the code visible here it is only referenced by a
# commented-out /api/ask handler — confirm it is still needed.
class QuestionResponse(BaseModel):
    answer: str
    # Pydantic copies mutable defaults per instance, so this shared [] is safe.
    sources: Optional[List[str]] = []
# Request body for a similarity search over the vector store.
class SearchRequest(BaseModel):
    query: str
    # Maximum number of results; defaults to 5. Because the type is Optional, a
    # client may send an explicit null, which search_documents forwards as
    # limit=None to the vector store.
    limit: Optional[int] = 5
# Response body for the system status endpoint.
class StatusResponse(BaseModel):
    status: str  # "healthy" or "unhealthy", as set by get_status
    message: str  # human-readable explanation of the status
    version: str  # API version string
# Global variables
# Shared RAG instance. It is assigned by ``lifespan`` at startup and stays None
# until then; request handlers check it before use.
rag_system: Optional[RAG] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared RAG system on startup and log on shutdown.

    Reads ``GOOGLE_API_KEY`` (required) and ``COLLECTION_NAME`` (default
    ``"ca_documents"``) from the environment and stores the resulting ``RAG``
    instance in the module-level ``rag_system``.

    Raises:
        ValueError: if ``GOOGLE_API_KEY`` is not set.
        Exception: anything raised while constructing ``RAG`` is logged and
            re-raised, so it propagates out of application startup.
    """
    global rag_system
    # Startup
    try:
        google_api_key = os.getenv("GOOGLE_API_KEY")
        if not google_api_key:
            raise ValueError("GOOGLE_API_KEY environment variable not set")
        collection_name = os.getenv("COLLECTION_NAME", "ca_documents")
        rag_system = RAG(google_api_key, collection_name)
        logger.info("RAG system initialized successfully")
    except Exception as e:
        # logger.exception keeps the traceback, which is what you need to debug
        # a failed startup.
        logger.exception(f"Failed to initialize RAG system: {e}")
        raise
    try:
        yield
    finally:
        # Shutdown: logged even if the app exits with an error.
        logger.info("Shutting down...")
# Create FastAPI app. ``lifespan`` builds the RAG system at startup and logs on
# shutdown.
app = FastAPI(
    title="CA Study Assistant API",
    description="Backend API for the CA Study Assistant RAG system",
    version="2.0.0",  # get_status reports a separately hard-coded "2.0.0"; keep them in sync
    lifespan=lifespan
)
# CORS middleware.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable and default to "*" (any origin). Browsers refuse credentialed
# responses that carry a wildcard origin. Starlette works around that by echoing
# back the caller's Origin, which would let any website make credentialed
# requests. Credentials are therefore only allowed when an explicit origin list
# is configured.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Health check endpoint
async def health_check():
    """Report that the API process is up. The RAG system is not inspected."""
    payload = {
        "status": "healthy",
        "message": "CA Study Assistant API is running",
    }
    return payload
| # API Routes | |
| # @app.post("/api/ask", response_model=QuestionResponse) | |
| # async def ask_question(request: QuestionRequest): | |
| # """ | |
| # Ask a question to the RAG system | |
| # """ | |
| # try: | |
| # if not rag_system: | |
| # raise HTTPException(status_code=500, detail="RAG system not initialized") | |
| # logger.info(f"Processing question: {request.question[:100]}...") | |
| # answer = rag_system.ask_question(request.question) | |
| # # Extract sources from the answer if they exist | |
| # sources = [] | |
| # if "Sources:" in answer: | |
| # parts = answer.split("Sources:") | |
| # if len(parts) > 1: | |
| # answer = parts[0].strip() | |
| # sources_text = parts[1].strip() | |
| # sources = [s.strip() for s in sources_text.split(",") if s.strip()] | |
| # return QuestionResponse(answer=answer, sources=sources) | |
| # except Exception as e: | |
| # logger.error(f"Error processing question: {e}") | |
| # raise HTTPException(status_code=500, detail=f"Error processing question: {str(e)}") | |
async def ask_question_stream(request: QuestionRequest):
    """
    Ask a question to the RAG system and get a streaming response.

    Streams the non-empty chunks yielded by ``rag_system.ask_question_stream``
    as ``text/plain``. If generation fails part-way through, an error message
    is appended to the stream; the response status has already been sent by
    then, so it cannot be changed.

    Raises:
        HTTPException: 500 if the RAG system is not initialized or the
            streaming response cannot be set up.
    """
    try:
        if not rag_system:
            raise HTTPException(status_code=500, detail="RAG system not initialized")
        logger.info(f"Processing streaming question: {request.question[:100]}...")
        question = request.question
        rag = rag_system

        # A plain (synchronous) generator: StreamingResponse iterates sync
        # iterators in a worker thread. The blocking RAG/LLM calls therefore
        # don't stall the event loop and every other request while a stream
        # is live.
        def event_generator():
            try:
                for chunk in rag.ask_question_stream(question):
                    if chunk:  # Only yield non-empty chunks
                        yield chunk
            except Exception as e:
                logger.exception(f"Error during stream generation: {e}")
                # This part may not be sent if the connection is already closed.
                yield f"Error generating answer: {str(e)}"

        return StreamingResponse(event_generator(), media_type="text/plain")
    except HTTPException:
        # Pass through unchanged; the generic branch below would otherwise
        # re-wrap it as "Error processing streaming question: 500: ...".
        raise
    except Exception as e:
        logger.exception(f"Error processing streaming question: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing streaming question: {str(e)}") from e
async def upload_document(file: UploadFile = File(...)):
    """
    Upload a document to the RAG system.

    The upload is written to a temporary file that keeps the original
    extension. That path is passed to ``rag_system.upload_document``, and the
    temporary file is always removed afterwards.

    Returns:
        dict with ``status``, ``message``, ``filename`` and ``size`` (bytes).

    Raises:
        HTTPException: 400 for a missing filename or unsupported extension;
            500 if the RAG system is not initialized or processing fails.
    """
    try:
        if not rag_system:
            raise HTTPException(status_code=500, detail="RAG system not initialized")

        # Validate file type
        allowed_extensions = ['.pdf', '.docx', '.txt']
        # UploadFile.filename may be None. Treat that as an unsupported type
        # instead of letting os.path.splitext raise TypeError (which became a 500).
        filename = file.filename or ""
        file_extension = os.path.splitext(filename)[1].lower()
        if file_extension not in allowed_extensions:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type. Allowed types: {', '.join(allowed_extensions)}"
            )

        # Read before creating the temp file, so a failed read leaves nothing behind.
        content = await file.read()
        temp_file_path = None
        try:
            # delete=False: the file is closed here and then reopened by path
            # inside rag_system.upload_document.
            with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
                # Record the path before writing, so cleanup still happens if
                # the write fails.
                temp_file_path = temp_file.name
                temp_file.write(content)

            # Process the uploaded file
            logger.info(f"Processing uploaded file: {file.filename}")
            success = rag_system.upload_document(temp_file_path)
            if success:
                return {
                    "status": "success",
                    "message": f"File '{file.filename}' uploaded and processed successfully",
                    "filename": file.filename,
                    "size": len(content)
                }
            raise HTTPException(status_code=500, detail="Failed to process uploaded file")
        finally:
            # Clean up temporary file
            if temp_file_path and os.path.exists(temp_file_path):
                os.unlink(temp_file_path)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error uploading document: {e}")
        raise HTTPException(status_code=500, detail=f"Error uploading document: {str(e)}") from e
async def search_documents(request: SearchRequest):
    """
    Search for similar documents.

    Returns:
        dict with ``status``, ``results`` (whatever
        ``vector_store.search_similar`` returns) and ``count``.

    Raises:
        HTTPException: 500 if the RAG system is not initialized or the
            search fails.
    """
    try:
        if not rag_system:
            raise HTTPException(status_code=500, detail="RAG system not initialized")
        results = rag_system.vector_store.search_similar(request.query, limit=request.limit)
        return {
            "status": "success",
            "results": results,
            "count": len(results)
        }
    except HTTPException:
        # Pass through unchanged; the generic branch below would otherwise
        # re-wrap it as "Error searching documents: 500: ...".
        raise
    except Exception as e:
        logger.exception(f"Error searching documents: {e}")
        raise HTTPException(status_code=500, detail=f"Error searching documents: {str(e)}") from e
async def get_status():
    """
    Get system status.

    Reports "healthy" once the RAG system has been initialized and
    "unhealthy" before that, together with the API version.
    """
    try:
        if rag_system:
            status, message = "healthy", "RAG system is operational"
        else:
            status, message = "unhealthy", "RAG system not initialized"
        return StatusResponse(status=status, message=message, version="2.0.0")
    except Exception as e:
        logger.error(f"Error getting status: {e}")
        raise HTTPException(status_code=500, detail=f"Error getting status: {str(e)}")
async def get_collection_info():
    """
    Get information about the vector collection.

    Returns:
        dict with ``status`` and ``collection_info`` (whatever
        ``vector_store.get_collection_info`` returns).

    Raises:
        HTTPException: 500 if the RAG system is not initialized or the
            lookup fails.
    """
    try:
        if not rag_system:
            raise HTTPException(status_code=500, detail="RAG system not initialized")
        info = rag_system.vector_store.get_collection_info()
        return {
            "status": "success",
            "collection_info": info
        }
    except HTTPException:
        # Pass through unchanged; the generic branch below would otherwise
        # re-wrap it as "Error getting collection info: 500: ...".
        raise
    except Exception as e:
        logger.exception(f"Error getting collection info: {e}")
        raise HTTPException(status_code=500, detail=f"Error getting collection info: {str(e)}") from e
# Location of the React production build. It is a relative path, so it is
# resolved against the process's current working directory, not this file.
frontend_build_path = "../frontend/build"
if os.path.exists(frontend_build_path):
    # Serve everything under <build>/static at the /static URL prefix.
    # NOTE(review): StaticFiles raises at import time if <build>/static is
    # missing, even when the build directory itself exists.
    app.mount("/static", StaticFiles(directory=f"{frontend_build_path}/static"), name="static")
async def serve_react_app(request: Request, full_path: str):
    """
    Serve React app for all non-API routes.

    - ``api/...`` paths that reach this handler get a 404 (no API route matched).
    - Paths containing a dot are treated as build assets. They are served from
      the build directory only if they exist inside it.
    - Anything else returns ``index.html`` so React Router can handle it.

    Raises:
        HTTPException: 404 for unmatched API paths, and for asset paths that
            do not exist or resolve outside the build directory.
    """
    # If it's an API route, let FastAPI handle it
    if full_path.startswith("api/"):
        raise HTTPException(status_code=404, detail="API endpoint not found")

    # For static files (images, etc.)
    if "." in full_path:
        # full_path comes straight from the URL. Without this check, "../"
        # segments or an absolute path could reach any file the process can read.
        build_root = os.path.realpath(frontend_build_path)
        file_path = os.path.realpath(os.path.join(build_root, full_path))
        try:
            inside_build = os.path.commonpath([build_root, file_path]) == build_root
        except ValueError:  # e.g. paths on different drives (Windows)
            inside_build = False
        # isfile (not exists): a directory whose name contains a dot must not be
        # handed to FileResponse.
        if inside_build and os.path.isfile(file_path):
            return FileResponse(file_path)
        raise HTTPException(status_code=404, detail="File not found")

    # For all other routes, serve index.html (React Router will handle it)
    return FileResponse(f"{frontend_build_path}/index.html")
# Error handlers
async def not_found_handler(request: Request, exc: HTTPException):
    """
    Handle 404s: JSON for API paths, the React app's index.html otherwise.

    Non-API paths fall back to index.html so client-side routes still load.
    If the frontend has not been built, a JSON 404 explains how to build it.
    """
    if request.url.path.startswith("/api/"):
        return JSONResponse(status_code=404, content={"detail": "API endpoint not found"})

    index_html = f"{frontend_build_path}/index.html"
    if not os.path.exists(index_html):
        return JSONResponse(
            status_code=404,
            content={"detail": "React app not built. Run 'npm run build' in the frontend directory."},
        )
    # For non-API routes, serve React app
    return FileResponse(index_html)
async def internal_error_handler(request: Request, exc: Exception):
    """
    Log the exception and return a generic JSON 500 response.

    The response carries a fixed message only, so exception text is not
    exposed to the client.
    """
    logger.error(f"Internal server error: {exc}")
    body = {"detail": "Internal server error"}
    return JSONResponse(content=body, status_code=500)
if __name__ == "__main__":
    # Listen on $PORT when the environment provides it, otherwise on 8000.
    server_port = int(os.getenv("PORT", 8000))
    uvicorn_options = {
        "host": "0.0.0.0",
        "port": server_port,
        # NOTE(review): auto-reload is a development feature; consider
        # turning it off for deployed instances.
        "reload": True,
        "log_level": "info",
    }
    # The import string assumes this module is importable as "backend_api".
    uvicorn.run("backend_api:app", **uvicorn_options)