File size: 5,571 Bytes
d90bce4
b09d5e9
d90bce4
 
 
 
b09d5e9
 
 
 
d90bce4
b09d5e9
 
 
 
 
 
 
 
 
 
 
 
 
d90bce4
 
 
 
 
 
 
 
 
 
 
 
 
 
b09d5e9
 
 
 
 
 
 
 
 
 
 
 
 
d90bce4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import tempfile
from pathlib import Path
from typing import Any, Callable, List, Literal, Optional

from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from pydantic import BaseModel

# Router that the RAG query and OCR endpoints in this module register on via
# @router.post; presumably included into the FastAPI app elsewhere (TODO confirm).
router = APIRouter()

from rag import get_smart_rag_response
from ocr_service import get_ocr_service
class Message(BaseModel):
    # One turn of a conversation history. Pydantic rejects any role other
    # than the two literals below.
    role: Literal["user", "assistant"]
    # The turn's text.
    content: str

class QueryRequest(BaseModel):
    # Request body for POST /query/.
    # The user's question, forwarded to get_smart_rag_response and echoed back
    # in QueryResponse.query.
    query: str
    # Prior turns; converted to plain dicts before being passed to
    # get_smart_rag_response. Expected ordering (presumably oldest first) is
    # defined by that function — TODO confirm.
    # NOTE: the `[]` default is safe here: pydantic copies mutable field
    # defaults per instance, so requests do not share one list.
    conversation_history: List[Message] = []

class QueryResponse(BaseModel):
    # Response body for both POST /query/ and POST /ocr/query/.
    # The user's original query (for /ocr/query/ this is NOT the OCR-augmented
    # prompt, only request.query).
    query: str
    # The answer text returned by get_smart_rag_response.
    response: str
    # Source label returned by get_smart_rag_response; /ocr/query/ appends
    # " (OCR-enhanced)" to it.
    source: str

# OCR-specific models
class OCRResponse(BaseModel):
    # Response body for the /ocr/extract-*/ and /ocr/convert-to-markdown/
    # endpoints. Built as OCRResponse(**result) from the dict returned by the
    # OCR service, so the service's keys must match these field names.
    # Which optional fields get populated depends on the OCR service method
    # called (not visible in this module).
    # Whether the OCR call succeeded, as reported by the service.
    success: bool
    extracted_text: Optional[str] = None
    markdown_text: Optional[str] = None
    # Service-reported error message; presumably set when success is False
    # (TODO confirm against ocr_service).
    error: Optional[str] = None
    # NOTE(review): the endpoints delete their temp image before responding,
    # so if the service echoes that path here it no longer exists.
    image_path: Optional[str] = None
    model_used: Optional[str] = None

class OCRQueryRequest(BaseModel):
    # Request body for POST /ocr/query/.
    # The user's question; echoed back unchanged in QueryResponse.query.
    query: str
    # Prior turns, converted to dicts before being passed to the RAG layer.
    # The `[]` default is copied per instance by pydantic, so it is not shared.
    conversation_history: List[Message] = []
    # Text appended to the query under an "Extracted text from image:" heading;
    # presumably obtained from one of the /ocr/ extraction endpoints (TODO confirm).
    extracted_text: str

@router.post("/query/", response_model=QueryResponse)
async def query_rag_system(request: QueryRequest):
    """
    Run a query through the RAG system, passing along the prior conversation history.
    """
    try:
        # get_smart_rag_response receives the history as plain dicts, not
        # Pydantic models.
        # NOTE(review): `.dict()` is deprecated in Pydantic v2 in favour of
        # `.model_dump()`; switch if the project is on v2.
        history = [msg.dict() for msg in request.conversation_history]
        response, source = await get_smart_rag_response(request.query, history)
        return QueryResponse(
            query=request.query,
            response=response,
            source=source,
        )
    except Exception as e:
        # Any failure is surfaced as a 500 whose detail is the exception
        # message; chained so server-side tracebacks keep the root cause.
        raise HTTPException(status_code=500, detail=str(e)) from e

# OCR Endpoints
async def _run_ocr_on_upload(
    file: UploadFile,
    run_ocr: Callable[[Any, str], dict],
) -> OCRResponse:
    """Validate an uploaded image, run one OCR call on a temp copy, and clean up.

    Args:
        file: The uploaded file; its content type must start with ``image/``.
        run_ocr: Called as ``run_ocr(ocr_service, temp_file_path)``; must return
            a dict whose keys match the ``OCRResponse`` fields.

    Returns:
        The OCR result wrapped in an ``OCRResponse``.

    Raises:
        HTTPException: 400 if the upload is not an image (validated before the
            ``try`` so it is not rewrapped as a 500); 500 for any other failure,
            with the exception message as ``detail``.
    """
    # content_type may be None when the client omits the part's Content-Type.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    # filename may also be missing; Path("").suffix is "" (no suffix).
    suffix = Path(file.filename or "").suffix
    temp_file_path: Optional[str] = None
    try:
        # delete=False keeps the file on disk after this handle closes so the
        # OCR service can open it by path; the finally block removes it.
        # The path is recorded before reading so a failed read is still cleaned up.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(await file.read())

        # NOTE(review): run_ocr is synchronous and executes on the event-loop
        # thread, blocking other requests while it runs. Consider
        # fastapi.concurrency.run_in_threadpool once the OCR service is known
        # to be safe to call concurrently.
        result = run_ocr(get_ocr_service(), temp_file_path)
        return OCRResponse(**result)
    except HTTPException:
        # Let deliberate HTTP errors through with their own status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)


@router.post("/ocr/extract-text/", response_model=OCRResponse)
async def extract_text_from_image(
    file: UploadFile = File(...),
    prompt: Optional[str] = Form(None)
):
    """
    Extract text from uploaded image using DeepSeek OCR

    Returns 400 if the upload is not an image.
    """
    return await _run_ocr_on_upload(
        file,
        lambda service, path: service.extract_text_from_image(path, prompt),
    )


@router.post("/ocr/extract-with-grounding/", response_model=OCRResponse)
async def extract_text_with_grounding(
    file: UploadFile = File(...),
    target_text: Optional[str] = Form(None)
):
    """
    Extract text with grounding capabilities (locate specific text)

    Returns 400 if the upload is not an image.
    """
    return await _run_ocr_on_upload(
        file,
        lambda service, path: service.extract_text_with_grounding(path, target_text),
    )


@router.post("/ocr/convert-to-markdown/", response_model=OCRResponse)
async def convert_image_to_markdown(file: UploadFile = File(...)):
    """
    Convert document image to markdown format

    Returns 400 if the upload is not an image.
    """
    return await _run_ocr_on_upload(
        file,
        lambda service, path: service.convert_to_markdown(path),
    )

@router.post("/ocr/query/", response_model=QueryResponse)
async def query_with_ocr_text(request: OCRQueryRequest):
    """
    Query the RAG system with OCR extracted text
    """
    # The docstring above is kept verbatim: FastAPI publishes it as this
    # endpoint's OpenAPI description.
    try:
        # Put the OCR output after the user's question, separated by a blank
        # line and a heading, so both reach the RAG layer as one query string.
        ocr_section = f"Extracted text from image:\n{request.extracted_text}"
        enriched_query = "\n\n".join((request.query, ocr_section))

        # The RAG layer receives the prior turns as plain dicts.
        prior_turns = [turn.dict() for turn in request.conversation_history]

        answer, origin = await get_smart_rag_response(enriched_query, prior_turns)

        # Echo the user's original query (not the enriched one) and tag the
        # source label so clients can see OCR text was involved.
        return QueryResponse(
            query=request.query,
            response=answer,
            source=f"{origin} (OCR-enhanced)",
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))