yadavkapil23 commited on
Commit
d90bce4
·
1 Parent(s): 361dba0

updated endp acc to OCR

Browse files
Files changed (1) hide show
  1. endpoints.py +136 -2
endpoints.py CHANGED
@@ -1,10 +1,14 @@
1
- from fastapi import APIRouter, HTTPException
2
  from pydantic import BaseModel
3
- from typing import List, Literal
 
 
 
4
 
5
  router = APIRouter()
6
 
7
  from rag import get_smart_rag_response
 
8
 
9
  # Pydantic models for request/response validation
10
  class Message(BaseModel):
@@ -20,6 +24,20 @@ class QueryResponse(BaseModel):
20
  response: str
21
  source: str
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  @router.post("/query/")
24
  async def query_rag_system(request: QueryRequest):
25
  try:
@@ -33,3 +51,119 @@ async def query_rag_system(request: QueryRequest):
33
  )
34
  except Exception as e:
35
  raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException, UploadFile, File, Form
2
  from pydantic import BaseModel
3
+ from typing import List, Literal, Optional
4
+ import os
5
+ import tempfile
6
+ from pathlib import Path
7
 
8
  router = APIRouter()
9
 
10
  from rag import get_smart_rag_response
11
+ from ocr_service import get_ocr_service
12
 
13
  # Pydantic models for request/response validation
14
  class Message(BaseModel):
 
24
  response: str
25
  source: str
26
 
27
+ # OCR-specific models
28
+ class OCRResponse(BaseModel):
29
+ success: bool
30
+ extracted_text: Optional[str] = None
31
+ markdown_text: Optional[str] = None
32
+ error: Optional[str] = None
33
+ image_path: Optional[str] = None
34
+ model_used: Optional[str] = None
35
+
36
+ class OCRQueryRequest(BaseModel):
37
+ query: str
38
+ conversation_history: List[Message] = []
39
+ extracted_text: str
40
+
41
  @router.post("/query/")
42
  async def query_rag_system(request: QueryRequest):
43
  try:
 
51
  )
52
  except Exception as e:
53
  raise HTTPException(status_code=500, detail=str(e))
54
+
55
+ # OCR Endpoints
56
+ @router.post("/ocr/extract-text/", response_model=OCRResponse)
57
+ async def extract_text_from_image(
58
+ file: UploadFile = File(...),
59
+ prompt: Optional[str] = Form(None)
60
+ ):
61
+ """
62
+ Extract text from uploaded image using DeepSeek OCR
63
+ """
64
+ try:
65
+ # Validate file type
66
+ if not file.content_type.startswith('image/'):
67
+ raise HTTPException(status_code=400, detail="File must be an image")
68
+
69
+ # Create temporary file
70
+ with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as temp_file:
71
+ content = await file.read()
72
+ temp_file.write(content)
73
+ temp_file_path = temp_file.name
74
+
75
+ try:
76
+ # Get OCR service and extract text
77
+ ocr_service = get_ocr_service()
78
+ result = ocr_service.extract_text_from_image(temp_file_path, prompt)
79
+
80
+ return OCRResponse(**result)
81
+
82
+ finally:
83
+ # Clean up temporary file
84
+ if os.path.exists(temp_file_path):
85
+ os.unlink(temp_file_path)
86
+
87
+ except Exception as e:
88
+ raise HTTPException(status_code=500, detail=str(e))
89
+
90
+ @router.post("/ocr/extract-with-grounding/", response_model=OCRResponse)
91
+ async def extract_text_with_grounding(
92
+ file: UploadFile = File(...),
93
+ target_text: Optional[str] = Form(None)
94
+ ):
95
+ """
96
+ Extract text with grounding capabilities (locate specific text)
97
+ """
98
+ try:
99
+ if not file.content_type.startswith('image/'):
100
+ raise HTTPException(status_code=400, detail="File must be an image")
101
+
102
+ with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as temp_file:
103
+ content = await file.read()
104
+ temp_file.write(content)
105
+ temp_file_path = temp_file.name
106
+
107
+ try:
108
+ ocr_service = get_ocr_service()
109
+ result = ocr_service.extract_text_with_grounding(temp_file_path, target_text)
110
+
111
+ return OCRResponse(**result)
112
+
113
+ finally:
114
+ if os.path.exists(temp_file_path):
115
+ os.unlink(temp_file_path)
116
+
117
+ except Exception as e:
118
+ raise HTTPException(status_code=500, detail=str(e))
119
+
120
+ @router.post("/ocr/convert-to-markdown/", response_model=OCRResponse)
121
+ async def convert_image_to_markdown(file: UploadFile = File(...)):
122
+ """
123
+ Convert document image to markdown format
124
+ """
125
+ try:
126
+ if not file.content_type.startswith('image/'):
127
+ raise HTTPException(status_code=400, detail="File must be an image")
128
+
129
+ with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as temp_file:
130
+ content = await file.read()
131
+ temp_file.write(content)
132
+ temp_file_path = temp_file.name
133
+
134
+ try:
135
+ ocr_service = get_ocr_service()
136
+ result = ocr_service.convert_to_markdown(temp_file_path)
137
+
138
+ return OCRResponse(**result)
139
+
140
+ finally:
141
+ if os.path.exists(temp_file_path):
142
+ os.unlink(temp_file_path)
143
+
144
+ except Exception as e:
145
+ raise HTTPException(status_code=500, detail=str(e))
146
+
147
+ @router.post("/ocr/query/", response_model=QueryResponse)
148
+ async def query_with_ocr_text(request: OCRQueryRequest):
149
+ """
150
+ Query the RAG system with OCR extracted text
151
+ """
152
+ try:
153
+ # Combine the original query with extracted text
154
+ combined_query = f"{request.query}\n\nExtracted text from image:\n{request.extracted_text}"
155
+
156
+ # Convert conversation history
157
+ history = [msg.dict() for msg in request.conversation_history]
158
+
159
+ # Get RAG response
160
+ response, source = await get_smart_rag_response(combined_query, history)
161
+
162
+ return QueryResponse(
163
+ query=request.query,
164
+ response=response,
165
+ source=f"{source} (OCR-enhanced)"
166
+ )
167
+
168
+ except Exception as e:
169
+ raise HTTPException(status_code=500, detail=str(e))