# Corex / endpoints.py
# Last commit d90bce4 by yadavkapil23: "updated endp acc to OCR" (5.62 kB)
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from pydantic import BaseModel
from typing import List, Literal, Optional
import os
import tempfile
from pathlib import Path
# Router that every endpoint in this module is registered on; presumably
# mounted on the FastAPI app via include_router elsewhere — confirm.
router = APIRouter()
from rag import get_smart_rag_response
from ocr_service import get_ocr_service
# Pydantic models for request/response validation
class Message(BaseModel):
    """One turn of a chat conversation, as sent by the client."""

    # Author of the turn; validation rejects anything other than these two roles.
    role: Literal["user", "assistant"]
    # Text of the turn.
    content: str
class QueryRequest(BaseModel):
    """Request body for the plain RAG query endpoint (``/query/``)."""

    # The user's question.
    query: str
    # Earlier turns of the conversation; empty when omitted by the client.
    conversation_history: List[Message] = []
class QueryResponse(BaseModel):
    """Response body shared by the RAG query endpoints."""

    # The query exactly as the client submitted it.
    query: str
    # Answer text produced by the RAG pipeline.
    response: str
    # Label describing where the answer came from, as reported by the RAG pipeline.
    source: str
# OCR-specific models
class OCRResponse(BaseModel):
    """Result of an OCR call, built from the dict returned by the OCR service."""

    # Whether the OCR service reported success.
    success: bool
    # Plain text recognised in the image, if the service returned any.
    extracted_text: Optional[str] = None
    # Markdown rendering of the document, if the service returned one.
    markdown_text: Optional[str] = None
    # Error description from the OCR service (presumably set when success is False).
    error: Optional[str] = None
    # Image path as reported by the service; NOTE(review): likely the server-side
    # temp file, which is deleted before the response is sent — confirm.
    image_path: Optional[str] = None
    # Identifier of the OCR model that produced the result, as reported by the service.
    model_used: Optional[str] = None
class OCRQueryRequest(BaseModel):
    """Request body for querying the RAG system with OCR-extracted text as context."""

    # The user's question.
    query: str
    # Earlier turns of the conversation; empty when omitted by the client.
    conversation_history: List[Message] = []
    # OCR text supplied by the client (presumably obtained from an /ocr/ extraction
    # endpoint); it is appended below the query before calling the RAG pipeline.
    extracted_text: str
@router.post("/query/", response_model=QueryResponse)
async def query_rag_system(request: QueryRequest):
    """
    Answer a query with the RAG system, passing along the prior conversation turns.

    Returns:
        QueryResponse echoing the query together with the generated response and
        the source label reported by ``get_smart_rag_response``.

    Raises:
        HTTPException: 500 with the underlying error message if anything fails.
    """
    try:
        # get_smart_rag_response is given the history as plain dicts, not Pydantic models.
        history = [msg.dict() for msg in request.conversation_history]
        response, source = await get_smart_rag_response(request.query, history)
        return QueryResponse(
            query=request.query,
            response=response,
            source=source,
        )
    except Exception as e:
        # Chain the cause so the original traceback is preserved as __cause__.
        raise HTTPException(status_code=500, detail=str(e)) from e
# OCR Endpoints


async def _run_ocr_on_upload(file: UploadFile, run_ocr) -> OCRResponse:
    """
    Validate an uploaded image, spool it to a temporary file and run OCR on it.

    Shared by the /ocr/ upload endpoints so that validation, error mapping and
    temp-file cleanup behave identically for all of them.

    Args:
        file: The uploaded file from the multipart request.
        run_ocr: Callable that receives the temporary file's path and returns a
            dict of ``OCRResponse`` fields.

    Returns:
        ``OCRResponse`` built from the dict returned by ``run_ocr``.

    Raises:
        HTTPException: 400 if the upload is not declared as an image;
            500 with the error message for any other failure.
    """
    # content_type is None when the client sends no Content-Type for the part;
    # treat that as "not an image" instead of crashing with AttributeError.
    # This check is outside the try below so the 400 is not re-wrapped as a 500.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    temp_file_path = None
    try:
        # filename may also be missing; Path("") simply yields an empty suffix.
        suffix = Path(file.filename or "").suffix
        # delete=False keeps the file on disk after it is closed so the OCR
        # service can open it by path; it is removed in the finally block below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path first so cleanup still runs if reading the upload fails.
            temp_file_path = temp_file.name
            temp_file.write(await file.read())
        # NOTE(review): run_ocr executes synchronously on the event loop; if OCR is
        # slow, consider fastapi.concurrency.run_in_threadpool once the OCR service
        # is confirmed to be thread-safe.
        result = run_ocr(temp_file_path)
        return OCRResponse(**result)
    except HTTPException:
        # Let deliberate HTTP errors keep their status code instead of becoming 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)


@router.post("/ocr/extract-text/", response_model=OCRResponse)
async def extract_text_from_image(
    file: UploadFile = File(...),
    prompt: Optional[str] = Form(None),
):
    """
    Extract text from uploaded image using DeepSeek OCR.

    The optional ``prompt`` form field is passed through to the OCR service.
    Responds 400 for non-image uploads and 500 for any processing failure.
    """
    return await _run_ocr_on_upload(
        file,
        lambda path: get_ocr_service().extract_text_from_image(path, prompt),
    )


@router.post("/ocr/extract-with-grounding/", response_model=OCRResponse)
async def extract_text_with_grounding(
    file: UploadFile = File(...),
    target_text: Optional[str] = Form(None),
):
    """
    Extract text with grounding capabilities (locate specific text).

    The optional ``target_text`` form field is passed through to the OCR service.
    Responds 400 for non-image uploads and 500 for any processing failure.
    """
    return await _run_ocr_on_upload(
        file,
        lambda path: get_ocr_service().extract_text_with_grounding(path, target_text),
    )


@router.post("/ocr/convert-to-markdown/", response_model=OCRResponse)
async def convert_image_to_markdown(file: UploadFile = File(...)):
    """
    Convert document image to markdown format.

    Responds 400 for non-image uploads and 500 for any processing failure.
    """
    return await _run_ocr_on_upload(
        file,
        lambda path: get_ocr_service().convert_to_markdown(path),
    )
@router.post("/ocr/query/", response_model=QueryResponse)
async def query_with_ocr_text(request: OCRQueryRequest):
    """
    Answer a query with the RAG system, using OCR-extracted text as extra context.

    The OCR text is appended below the user's query before it is sent to the
    RAG pipeline. The response echoes only the original query, and its source
    label is suffixed with " (OCR-enhanced)". Any failure is reported as
    HTTP 500 with the exception message as the detail.
    """
    try:
        user_query = request.query
        ocr_text = request.extracted_text
        prompt = f"{user_query}\n\nExtracted text from image:\n{ocr_text}"
        prior_turns = [turn.dict() for turn in request.conversation_history]

        answer, origin = await get_smart_rag_response(prompt, prior_turns)

        tagged_source = f"{origin} (OCR-enhanced)"
        return QueryResponse(query=user_query, response=answer, source=tagged_source)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))