# Corex / endpoints.py
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from pydantic import BaseModel
from typing import List, Literal, Optional
import os
import tempfile
from pathlib import Path

from rag import get_smart_rag_response
from ocr_service import get_ocr_service

router = APIRouter()

class Message(BaseModel):
    role: Literal["user", "assistant"]
    content: str


class QueryRequest(BaseModel):
    query: str
    conversation_history: List[Message] = []


class QueryResponse(BaseModel):
    query: str
    response: str
    source: str


# OCR-specific models
class OCRResponse(BaseModel):
    success: bool
    extracted_text: Optional[str] = None
    markdown_text: Optional[str] = None
    error: Optional[str] = None
    image_path: Optional[str] = None
    model_used: Optional[str] = None


class OCRQueryRequest(BaseModel):
    query: str
    conversation_history: List[Message] = []
    extracted_text: str
@router.post("/query/")
async def query_rag_system(request: QueryRequest):
try:
# Convert Pydantic models to dicts for processing
history = [msg.dict() for msg in request.conversation_history]
response, source = await get_smart_rag_response(request.query, history)
return QueryResponse(
query=request.query,
response=response,
source=source
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
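
# Example (illustrative, not part of the app): a client could call POST /query/
# with a JSON body matching QueryRequest. The host/port and root mount path are
# assumptions; adjust them to how this router is included in the application.
#
#   import httpx
#   payload = {
#       "query": "What is Corex?",
#       "conversation_history": [
#           {"role": "user", "content": "Hi"},
#           {"role": "assistant", "content": "Hello! How can I help?"},
#       ],
#   }
#   r = httpx.post("http://localhost:8000/query/", json=payload)
#   print(r.json())  # {"query": ..., "response": ..., "source": ...}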

# OCR Endpoints
@router.post("/ocr/extract-text/", response_model=OCRResponse)
async def extract_text_from_image(
    file: UploadFile = File(...),
    prompt: Optional[str] = Form(None)
):
    """
    Extract text from uploaded image using DeepSeek OCR
    """
    try:
        # Validate file type
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")
        # Create temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename or "").suffix) as temp_file:
            content = await file.read()
            temp_file.write(content)
            temp_file_path = temp_file.name
        try:
            # Get OCR service and extract text
            ocr_service = get_ocr_service()
            result = ocr_service.extract_text_from_image(temp_file_path, prompt)
            return OCRResponse(**result)
        finally:
            # Clean up temporary file
            if os.path.exists(temp_file_path):
                os.unlink(temp_file_path)
    except HTTPException:
        # Re-raise client errors (e.g. the 400 above) instead of converting them to 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
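
# Example (illustrative): uploading an image with httpx. The host/port are
# assumptions and "page.png" is a placeholder file name.
#
#   import httpx
#   with open("page.png", "rb") as f:
#       r = httpx.post(
#           "http://localhost:8000/ocr/extract-text/",
#           files={"file": ("page.png", f, "image/png")},
#           data={"prompt": "Extract all visible text"},
#       )
#   print(r.json()["extracted_text"])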
@router.post("/ocr/extract-with-grounding/", response_model=OCRResponse)
async def extract_text_with_grounding(
file: UploadFile = File(...),
target_text: Optional[str] = Form(None)
):
"""
Extract text with grounding capabilities (locate specific text)
"""
try:
if not file.content_type.startswith('image/'):
raise HTTPException(status_code=400, detail="File must be an image")
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as temp_file:
content = await file.read()
temp_file.write(content)
temp_file_path = temp_file.name
try:
ocr_service = get_ocr_service()
result = ocr_service.extract_text_with_grounding(temp_file_path, target_text)
return OCRResponse(**result)
finally:
if os.path.exists(temp_file_path):
os.unlink(temp_file_path)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/ocr/convert-to-markdown/", response_model=OCRResponse)
async def convert_image_to_markdown(file: UploadFile = File(...)):
"""
Convert document image to markdown format
"""
try:
if not file.content_type.startswith('image/'):
raise HTTPException(status_code=400, detail="File must be an image")
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as temp_file:
content = await file.read()
temp_file.write(content)
temp_file_path = temp_file.name
try:
ocr_service = get_ocr_service()
result = ocr_service.convert_to_markdown(temp_file_path)
return OCRResponse(**result)
finally:
if os.path.exists(temp_file_path):
os.unlink(temp_file_path)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/ocr/query/", response_model=QueryResponse)
async def query_with_ocr_text(request: OCRQueryRequest):
"""
Query the RAG system with OCR extracted text
"""
try:
# Combine the original query with extracted text
combined_query = f"{request.query}\n\nExtracted text from image:\n{request.extracted_text}"
# Convert conversation history
history = [msg.dict() for msg in request.conversation_history]
# Get RAG response
response, source = await get_smart_rag_response(combined_query, history)
return QueryResponse(
query=request.query,
response=response,
source=f"{source} (OCR-enhanced)"
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
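
# Example (illustrative): chaining the OCR and RAG endpoints. A client first
# extracts text via /ocr/extract-text/ and then sends it to /ocr/query/.
# Host, port, and the file name are assumptions for this sketch.
#
#   import httpx
#   with open("invoice.png", "rb") as f:
#       ocr = httpx.post(
#           "http://localhost:8000/ocr/extract-text/",
#           files={"file": ("invoice.png", f, "image/png")},
#       ).json()
#   if ocr.get("success"):
#       answer = httpx.post(
#           "http://localhost:8000/ocr/query/",
#           json={
#               "query": "Summarize this document",
#               "conversation_history": [],
#               "extracted_text": ocr["extracted_text"],
#           },
#       ).json()
#       print(answer["response"])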