Spaces:
Running
Running
| from fastapi import APIRouter, HTTPException, UploadFile, File, Form | |
| from pydantic import BaseModel | |
| from typing import List, Literal, Optional | |
| import os | |
| import tempfile | |
| from pathlib import Path | |
| router = APIRouter() | |
| from rag import get_smart_rag_response | |
| from ocr_service import get_ocr_service | |
# A single turn of a conversation, as exchanged with the client.
class Message(BaseModel):
    # Who produced the turn; validation accepts only these two literals.
    role: Literal["user", "assistant"]
    # Text of the turn.
    content: str
# Request body for a plain RAG query.
class QueryRequest(BaseModel):
    # The question to answer.
    query: str
    # Earlier turns of the conversation; defaults to an empty list when omitted.
    conversation_history: List[Message] = []
# Response body returned by the query handlers in this module.
class QueryResponse(BaseModel):
    # The query as submitted by the client (echoed back unchanged).
    query: str
    # Answer text produced by get_smart_rag_response.
    response: str
    # Source label produced by get_smart_rag_response (handlers may decorate it).
    source: str
# OCR-specific models

# Result of an OCR operation. The OCR handlers build it by unpacking the
# mapping returned by the OCR service, so every field except `success`
# is optional and defaults to None.
class OCRResponse(BaseModel):
    # Whether the OCR operation succeeded (required).
    success: bool
    extracted_text: Optional[str] = None
    markdown_text: Optional[str] = None
    error: Optional[str] = None
    image_path: Optional[str] = None
    model_used: Optional[str] = None
# Request body for a RAG query that is accompanied by OCR output.
class OCRQueryRequest(BaseModel):
    # The question to answer.
    query: str
    # Earlier turns of the conversation; defaults to an empty list when omitted.
    conversation_history: List[Message] = []
    # Text previously extracted from an image; required.
    extracted_text: str
async def query_rag_system(request: QueryRequest):
    """
    Answer a query through the RAG helper, passing along the conversation so far.

    Returns a QueryResponse echoing the query together with the answer and
    source reported by get_smart_rag_response. Any exception raised while
    handling the request is re-raised as an HTTP 500 whose detail is the
    exception text.
    """
    try:
        # The RAG helper is handed plain dicts rather than Pydantic objects.
        prior_turns = []
        for turn in request.conversation_history:
            prior_turns.append(turn.dict())

        answer, origin = await get_smart_rag_response(request.query, prior_turns)

        # Built inside the try so that a validation failure also maps to 500.
        reply = QueryResponse(query=request.query, response=answer, source=origin)
        return reply
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
# OCR Endpoints
async def extract_text_from_image(
    file: UploadFile = File(...),
    prompt: Optional[str] = Form(None)
):
    """
    Extract text from uploaded image using DeepSeek OCR

    The upload is written to a temporary file, handed to the OCR service
    together with the optional prompt, and the temporary file is removed
    afterwards regardless of the outcome.

    Args:
        file: Uploaded image; its content type must start with "image/".
        prompt: Optional prompt forwarded unchanged to the OCR service.

    Returns:
        OCRResponse built by unpacking the mapping returned by the OCR service.

    Raises:
        HTTPException: 400 if the upload is not an image; 500 (with the
            exception text as detail) for any other failure.
    """
    # Validate before entering the try block so the 400 is not swallowed by
    # the generic handler below. content_type may be None when the client
    # does not send a Content-Type for the part.
    if not (file.content_type or "").startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")

    temp_file_path = None
    try:
        # filename may be None; fall back to no suffix in that case.
        suffix = Path(file.filename or "").suffix
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path first so cleanup still happens if read/write fails.
            temp_file_path = temp_file.name
            content = await file.read()
            temp_file.write(content)

        # NOTE(review): this call is made synchronously inside a coroutine; if
        # the OCR service blocks, it will stall the event loop — confirm.
        ocr_service = get_ocr_service()
        result = ocr_service.extract_text_from_image(temp_file_path, prompt)
        return OCRResponse(**result)
    except HTTPException:
        # Preserve deliberate HTTP errors (status code and detail) as raised.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up temporary file
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
async def extract_text_with_grounding(
    file: UploadFile = File(...),
    target_text: Optional[str] = Form(None)
):
    """
    Extract text with grounding capabilities (locate specific text)

    The upload is written to a temporary file, handed to the OCR service
    together with the optional target text, and the temporary file is
    removed afterwards regardless of the outcome.

    Args:
        file: Uploaded image; its content type must start with "image/".
        target_text: Optional text forwarded unchanged to the OCR service.

    Returns:
        OCRResponse built by unpacking the mapping returned by the OCR service.

    Raises:
        HTTPException: 400 if the upload is not an image; 500 (with the
            exception text as detail) for any other failure.
    """
    # Validate before entering the try block so the 400 is not swallowed by
    # the generic handler below. content_type may be None when the client
    # does not send a Content-Type for the part.
    if not (file.content_type or "").startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")

    temp_file_path = None
    try:
        # filename may be None; fall back to no suffix in that case.
        suffix = Path(file.filename or "").suffix
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path first so cleanup still happens if read/write fails.
            temp_file_path = temp_file.name
            content = await file.read()
            temp_file.write(content)

        # NOTE(review): this call is made synchronously inside a coroutine; if
        # the OCR service blocks, it will stall the event loop — confirm.
        ocr_service = get_ocr_service()
        result = ocr_service.extract_text_with_grounding(temp_file_path, target_text)
        return OCRResponse(**result)
    except HTTPException:
        # Preserve deliberate HTTP errors (status code and detail) as raised.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up temporary file
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
async def convert_image_to_markdown(file: UploadFile = File(...)):
    """
    Convert document image to markdown format

    The upload is written to a temporary file, handed to the OCR service's
    markdown conversion, and the temporary file is removed afterwards
    regardless of the outcome.

    Args:
        file: Uploaded image; its content type must start with "image/".

    Returns:
        OCRResponse built by unpacking the mapping returned by the OCR service.

    Raises:
        HTTPException: 400 if the upload is not an image; 500 (with the
            exception text as detail) for any other failure.
    """
    # Validate before entering the try block so the 400 is not swallowed by
    # the generic handler below. content_type may be None when the client
    # does not send a Content-Type for the part.
    if not (file.content_type or "").startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")

    temp_file_path = None
    try:
        # filename may be None; fall back to no suffix in that case.
        suffix = Path(file.filename or "").suffix
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path first so cleanup still happens if read/write fails.
            temp_file_path = temp_file.name
            content = await file.read()
            temp_file.write(content)

        # NOTE(review): this call is made synchronously inside a coroutine; if
        # the OCR service blocks, it will stall the event loop — confirm.
        ocr_service = get_ocr_service()
        result = ocr_service.convert_to_markdown(temp_file_path)
        return OCRResponse(**result)
    except HTTPException:
        # Preserve deliberate HTTP errors (status code and detail) as raised.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up temporary file
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
async def query_with_ocr_text(request: OCRQueryRequest):
    """
    Run a RAG query whose prompt is augmented with OCR-extracted text.

    The extracted text is appended beneath the user's query before it is
    sent to get_smart_rag_response. The returned QueryResponse echoes the
    original (un-augmented) query and tags the source as OCR-enhanced.
    Any exception raised while handling the request is re-raised as an
    HTTP 500 whose detail is the exception text.
    """
    try:
        # The RAG helper is handed plain dicts rather than Pydantic objects.
        prior_turns = [turn.dict() for turn in request.conversation_history]

        # Question first, then the OCR text under a fixed heading.
        augmented_query = f"{request.query}\n\nExtracted text from image:\n{request.extracted_text}"

        answer, origin = await get_smart_rag_response(augmented_query, prior_turns)

        tagged_source = f"{origin} (OCR-enhanced)"
        return QueryResponse(
            query=request.query,
            response=answer,
            source=tagged_source,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))