Spaces:
Running
Running
File size: 5,571 Bytes
d90bce4 b09d5e9 d90bce4 b09d5e9 d90bce4 b09d5e9 d90bce4 b09d5e9 d90bce4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from pydantic import BaseModel
from typing import List, Literal, Optional
import os
import tempfile
from pathlib import Path
# Router holding the RAG query and OCR endpoints declared in this module;
# presumably mounted on the app elsewhere via include_router — confirm.
router = APIRouter()
from rag import get_smart_rag_response
from ocr_service import get_ocr_service
class Message(BaseModel):
    """One turn of a conversation, as sent in a request's conversation history."""

    # Who produced this turn; any value other than "user" or "assistant"
    # fails validation.
    role: Literal["user", "assistant"]
    # Text of the turn.
    content: str
class QueryRequest(BaseModel):
    """Request body for the ``/query/`` endpoint."""

    # The user's question, passed to the RAG pipeline unchanged.
    query: str
    # Earlier turns, converted to dicts and forwarded to the RAG pipeline.
    # NOTE: Pydantic copies field defaults per instance, so this [] is not
    # shared between requests (unlike a mutable default function argument).
    conversation_history: List[Message] = []
class QueryResponse(BaseModel):
    """Response body returned by the RAG query endpoints."""

    # The query exactly as the client sent it (the OCR endpoint echoes the
    # original query here, not the prompt with extracted text appended).
    query: str
    # Answer text returned by the RAG pipeline.
    response: str
    # Source label returned by the RAG pipeline (the OCR endpoint appends
    # " (OCR-enhanced)").
    source: str
# OCR-specific models
class OCRResponse(BaseModel):
    """Result of an OCR call, built by unpacking the OCR service's result dict.

    Keys in that dict that are not declared here are ignored by the model's
    default configuration; only ``success`` is required.
    """

    # Whether the OCR service reports the operation as successful.
    success: bool
    # Plain text extracted from the image, if the service returned any.
    extracted_text: Optional[str] = None
    # Markdown rendering of the image; presumably filled by the markdown
    # conversion call — confirm against ocr_service.
    markdown_text: Optional[str] = None
    # Error description; presumably set when success is False — confirm.
    error: Optional[str] = None
    # Image path reported by the service; in these endpoints it is likely the
    # temporary upload path, which is deleted before the response is sent.
    image_path: Optional[str] = None
    # Model identifier reported by the OCR service, if any.
    model_used: Optional[str] = None
class OCRQueryRequest(BaseModel):
    """Request body for the ``/ocr/query/`` endpoint."""

    # The user's question.
    query: str
    # Earlier turns, converted to dicts and forwarded to the RAG pipeline.
    # Pydantic copies this [] default per instance, so it is not shared.
    conversation_history: List[Message] = []
    # Text previously extracted from an image; appended to the query that is
    # sent to the RAG pipeline.
    extracted_text: str
# response_model matches the object actually returned and the /ocr/query/
# endpoint, so the response schema is published in the OpenAPI docs.
@router.post("/query/", response_model=QueryResponse)
async def query_rag_system(request: QueryRequest):
    """
    Answer a query with the RAG system.

    The conversation history is forwarded to the RAG pipeline as plain dicts.
    Any failure is reported as an HTTP 500 whose detail is the error text.
    """
    try:
        # Convert Pydantic models to dicts for processing
        history = [msg.dict() for msg in request.conversation_history]
        response, source = await get_smart_rag_response(request.query, history)
        return QueryResponse(
            query=request.query,
            response=response,
            source=source,
        )
    except Exception as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=str(e)) from e
# OCR Endpoints
@router.post("/ocr/extract-text/", response_model=OCRResponse)
async def extract_text_from_image(
    file: UploadFile = File(...),
    prompt: Optional[str] = Form(None)
):
    """
    Extract text from uploaded image using DeepSeek OCR

    The upload is written to a temporary file, which is always removed before
    returning. Its path and the optional ``prompt`` are passed to the OCR
    service's ``extract_text_from_image``, and the resulting dict is returned
    as an ``OCRResponse``.

    Raises:
        HTTPException: 400 if the upload does not declare an ``image/*``
            content type; 500 for any other failure while saving the upload
            or running OCR.
    """
    # Validate before entering the try below: HTTPException is an Exception,
    # so raising it inside that try would turn the 400 into a 500.
    # content_type may be None when the client omits it.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    # filename may be missing; fall back to a temp file without an extension.
    suffix = Path(file.filename).suffix if file.filename else ""

    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path before reading or writing, so the finally clause
            # can remove the file even if saving the upload fails.
            temp_file_path = temp_file.name
            temp_file.write(await file.read())
        # NOTE(review): this OCR call is synchronous inside an async endpoint,
        # so it blocks the event loop while it runs; consider
        # run_in_threadpool if the OCR service is thread-safe.
        ocr_service = get_ocr_service()
        result = ocr_service.extract_text_from_image(temp_file_path, prompt)
        return OCRResponse(**result)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up temporary file
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
@router.post("/ocr/extract-with-grounding/", response_model=OCRResponse)
async def extract_text_with_grounding(
    file: UploadFile = File(...),
    target_text: Optional[str] = Form(None)
):
    """
    Extract text with grounding capabilities (locate specific text)

    The upload is written to a temporary file, which is always removed before
    returning. Its path and the optional ``target_text`` are passed to the OCR
    service's ``extract_text_with_grounding``, and the resulting dict is
    returned as an ``OCRResponse``.

    Raises:
        HTTPException: 400 if the upload does not declare an ``image/*``
            content type; 500 for any other failure while saving the upload
            or running OCR.
    """
    # Validate before entering the try below, so the 400 is not re-wrapped
    # as a 500 by the generic handler. content_type may be None.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    # filename may be missing; fall back to a temp file without an extension.
    suffix = Path(file.filename).suffix if file.filename else ""

    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path first, so cleanup also runs if saving fails.
            temp_file_path = temp_file.name
            temp_file.write(await file.read())
        # NOTE(review): synchronous OCR call inside an async endpoint; it
        # blocks the event loop while running.
        ocr_service = get_ocr_service()
        result = ocr_service.extract_text_with_grounding(temp_file_path, target_text)
        return OCRResponse(**result)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
@router.post("/ocr/convert-to-markdown/", response_model=OCRResponse)
async def convert_image_to_markdown(file: UploadFile = File(...)):
    """
    Convert document image to markdown format

    The upload is written to a temporary file, which is always removed before
    returning. Its path is passed to the OCR service's ``convert_to_markdown``,
    and the resulting dict is returned as an ``OCRResponse``.

    Raises:
        HTTPException: 400 if the upload does not declare an ``image/*``
            content type; 500 for any other failure while saving the upload
            or running the conversion.
    """
    # Validate before entering the try below, so the 400 is not re-wrapped
    # as a 500 by the generic handler. content_type may be None.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    # filename may be missing; fall back to a temp file without an extension.
    suffix = Path(file.filename).suffix if file.filename else ""

    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path first, so cleanup also runs if saving fails.
            temp_file_path = temp_file.name
            temp_file.write(await file.read())
        # NOTE(review): synchronous OCR call inside an async endpoint; it
        # blocks the event loop while running.
        ocr_service = get_ocr_service()
        result = ocr_service.convert_to_markdown(temp_file_path)
        return OCRResponse(**result)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
@router.post("/ocr/query/", response_model=QueryResponse)
async def query_with_ocr_text(request: OCRQueryRequest):
    """
    Answer a query with the RAG system, adding text previously extracted
    from an image.

    The extracted text is appended to the query under a fixed heading before
    the combined prompt is sent to the RAG pipeline. The response echoes the
    client's original query, and its source label is tagged "(OCR-enhanced)".
    Any failure is reported as an HTTP 500 whose detail is the error text.
    """
    try:
        augmented_prompt = f"{request.query}\n\nExtracted text from image:\n{request.extracted_text}"
        prior_turns = [turn.dict() for turn in request.conversation_history]
        answer, source = await get_smart_rag_response(augmented_prompt, prior_turns)
        labelled_source = f"{source} (OCR-enhanced)"
        # Echo the user's own query, not the augmented prompt.
        return QueryResponse(query=request.query, response=answer, source=labelled_source)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|