from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Literal, Optional
import os
import tempfile
from pathlib import Path

router = APIRouter()

# NOTE(review): these project imports sit after ``router`` is created, as in
# the original file. Hoist them to the top once it is confirmed that neither
# module imports this one (which would make the ordering significant).
from rag import get_smart_rag_response
from ocr_service import get_ocr_service


class Message(BaseModel):
    """A single turn of a chat conversation."""

    role: Literal["user", "assistant"]
    content: str


class QueryRequest(BaseModel):
    """Request body for ``POST /query/``."""

    query: str
    # Prior conversation turns; forwarded to the RAG backend as plain dicts.
    conversation_history: List[Message] = []


class QueryResponse(BaseModel):
    """Response body returned by the RAG query endpoints."""

    query: str  # The query exactly as submitted by the client.
    response: str  # Answer text produced by ``get_smart_rag_response``.
    source: str  # Source label reported by ``get_smart_rag_response``.


# OCR-specific models
class OCRResponse(BaseModel):
    """Response body for the OCR endpoints, built from the OCR service's result mapping."""

    success: bool
    extracted_text: Optional[str] = None
    markdown_text: Optional[str] = None
    error: Optional[str] = None
    image_path: Optional[str] = None
    model_used: Optional[str] = None


class OCRQueryRequest(BaseModel):
    """Request body for ``POST /ocr/query/``: a query plus text previously extracted by OCR."""

    query: str
    conversation_history: List[Message] = []
    extracted_text: str


def _history_to_dicts(messages: List[Message]) -> List[Dict[str, Any]]:
    """Convert conversation messages to plain dicts for the RAG backend."""
    return [msg.dict() for msg in messages]


def _ensure_image(file: UploadFile) -> None:
    """Raise a 400 ``HTTPException`` unless *file* declares an ``image/*`` content type."""
    # content_type is None when the client omits the header; treat that as
    # "not an image" rather than crashing with AttributeError.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")


async def _save_upload_to_temp(file: UploadFile) -> str:
    """Write the uploaded file's bytes to a new temporary file and return its path.

    The temporary file keeps the upload's filename suffix (if any). It is NOT
    deleted automatically, so the caller must remove it. If writing fails, the
    partially written file is removed before the error propagates.
    """
    content = await file.read()
    # filename may be None; Path("") simply has an empty suffix.
    suffix = Path(file.filename or "").suffix
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        with temp_file:
            temp_file.write(content)
    except BaseException:
        # delete=False means nobody else will clean this up.
        os.unlink(temp_file.name)
        raise
    return temp_file.name


def _remove_temp_file(path: str) -> None:
    """Delete *path* if it still exists."""
    if os.path.exists(path):
        os.unlink(path)


async def _run_ocr_on_upload(
    file: UploadFile,
    run: Callable[[Any, str], Dict[str, Any]],
) -> OCRResponse:
    """Validate *file*, save it to a temp file, apply *run*, and wrap the result.

    ``run`` receives the service returned by ``get_ocr_service()`` and the temp
    file path. It must return a mapping accepted by ``OCRResponse``. The temp
    file is always removed afterwards.

    Raises:
        HTTPException: 400 if the upload is not an image; 500 for any other
            failure, with the underlying error message as the detail.
    """
    try:
        _ensure_image(file)
        temp_file_path = await _save_upload_to_temp(file)
        try:
            # NOTE(review): the OCR service is called without ``await``, so it
            # runs on the event-loop thread and blocks other requests while it
            # works. Consider fastapi.concurrency.run_in_threadpool once the
            # service is confirmed to be thread-safe.
            result = run(get_ocr_service(), temp_file_path)
            return OCRResponse(**result)
        finally:
            _remove_temp_file(temp_file_path)
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) instead of
        # re-wrapping them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/query/", response_model=QueryResponse)
async def query_rag_system(request: QueryRequest):
    """Answer the query with the RAG system, taking the conversation history into account.

    Raises:
        HTTPException: 500 with the underlying error message if the RAG call fails.
    """
    try:
        history = _history_to_dicts(request.conversation_history)
        response, source = await get_smart_rag_response(request.query, history)
        return QueryResponse(
            query=request.query,
            response=response,
            source=source,
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


# OCR Endpoints
@router.post("/ocr/extract-text/", response_model=OCRResponse)
async def extract_text_from_image(
    file: UploadFile = File(...),
    prompt: Optional[str] = Form(None),
):
    """
    Extract text from uploaded image using DeepSeek OCR.

    An optional ``prompt`` form field is passed through to the OCR service.
    Returns 400 if the upload is not an image and 500 on any other failure.
    """
    return await _run_ocr_on_upload(
        file,
        lambda service, path: service.extract_text_from_image(path, prompt),
    )


@router.post("/ocr/extract-with-grounding/", response_model=OCRResponse)
async def extract_text_with_grounding(
    file: UploadFile = File(...),
    target_text: Optional[str] = Form(None),
):
    """
    Extract text with grounding capabilities (locate specific text).

    An optional ``target_text`` form field is passed through to the OCR service.
    Returns 400 if the upload is not an image and 500 on any other failure.
    """
    return await _run_ocr_on_upload(
        file,
        lambda service, path: service.extract_text_with_grounding(path, target_text),
    )


@router.post("/ocr/convert-to-markdown/", response_model=OCRResponse)
async def convert_image_to_markdown(file: UploadFile = File(...)):
    """
    Convert document image to markdown format.

    Returns 400 if the upload is not an image and 500 on any other failure.
    """
    return await _run_ocr_on_upload(
        file,
        lambda service, path: service.convert_to_markdown(path),
    )


@router.post("/ocr/query/", response_model=QueryResponse)
async def query_with_ocr_text(request: OCRQueryRequest):
    """
    Query the RAG system with OCR extracted text.

    The extracted text is appended to the query before it is sent to the RAG
    backend. The response echoes only the original query, and its source label
    is suffixed with " (OCR-enhanced)".
    """
    try:
        # Combine the original query with extracted text
        combined_query = f"{request.query}\n\nExtracted text from image:\n{request.extracted_text}"
        history = _history_to_dicts(request.conversation_history)
        response, source = await get_smart_rag_response(combined_query, history)
        return QueryResponse(
            query=request.query,
            response=response,
            source=f"{source} (OCR-enhanced)",
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e