Sandy2636 committed · Commit 2d6f97d · 1 Parent(s): 1d51bfe
Add application file
app.py CHANGED
@@ -15,35 +15,26 @@ OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
 
 # --- Global State (managed within Gradio's session if possible, or module-level for simplicity here) ---
 # This will be reset each time the processing function is called.
-# For a multi-user or more robust app, session state or a proper backend DB would be needed.
 processed_files_data = [] # Stores dicts for each file's details and status
 person_profiles = {} # Stores dicts for each identified person and their documents
 
 # --- Helper Functions ---
 
 def extract_json_from_text(text):
-    """
-    Extracts a JSON object from a string, trying common markdown and direct JSON.
-    """
     if not text:
         return {"error": "Empty text provided for JSON extraction."}
-
-    # Try to match ```json ... ``` code block
     match_block = re.search(r"```json\s*(\{.*?\})\s*```", text, re.DOTALL | re.IGNORECASE)
     if match_block:
         json_str = match_block.group(1)
     else:
-        # If no block, assume the text itself might be JSON or wrapped in single backticks
         text_stripped = text.strip()
         if text_stripped.startswith("`") and text_stripped.endswith("`"):
             json_str = text_stripped[1:-1]
         else:
             json_str = text_stripped
-
     try:
         return json.loads(json_str)
     except json.JSONDecodeError as e:
-        # Fallback: Try to find the first '{' and last '}' if initial parsing fails
         try:
             first_brace = json_str.find('{')
             last_brace = json_str.rfind('}')
@@ -55,7 +46,6 @@ def extract_json_from_text(text):
         except json.JSONDecodeError as e2:
             return {"error": f"Invalid JSON structure after attempting substring: {str(e2)}", "original_text": text}
 
-
 def get_ocr_prompt():
     return f"""You are an advanced OCR and information extraction AI.
 Your task is to meticulously analyze this image and extract all relevant information.
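For reference, a minimal usage sketch (not part of the commit) of the extract_json_from_text helper changed above, covering the three input shapes it handles: a fenced ```json block, a backtick-wrapped object, and bare JSON.

# Illustrative only; assumes the helper above is importable as-is.
fenced = '```json\n{"document_type_detected": "Passport", "extracted_fields": {"Full Name": "JANE DOE"}}\n```'
backticked = '`{"document_type_detected": "Hotel Bill"}`'
bare = '{"document_type_detected": "Unknown"}'
for sample in (fenced, backticked, bare):
    print(extract_json_from_text(sample))  # each call returns a dict and never raises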
@@ -89,17 +79,13 @@ def call_openrouter_ocr(image_filepath):
     try:
         with open(image_filepath, "rb") as f:
             encoded_image = base64.b64encode(f.read()).decode("utf-8")
-
-        # Basic MIME type guessing, default to jpeg
         mime_type = "image/jpeg"
         if image_filepath.lower().endswith(".png"):
             mime_type = "image/png"
         elif image_filepath.lower().endswith(".webp"):
             mime_type = "image/webp"
-
         data_url = f"data:{mime_type};base64,{encoded_image}"
         prompt_text = get_ocr_prompt()
-
         payload = {
             "model": IMAGE_MODEL,
             "messages": [
@@ -111,26 +97,23 @@ def call_openrouter_ocr(image_filepath):
                     ]
                 }
             ],
             "max_tokens": 3500,
             "temperature": 0.1,
         }
         headers = {
             "Authorization": f"Bearer {OPENROUTER_API_KEY}",
             "Content-Type": "application/json",
-            "HTTP-Referer": "https://huggingface.co/spaces/
-            "X-Title": "
+            "HTTP-Referer": "https://huggingface.co/spaces/YOUR_SPACE",
+            "X-Title": "Gradio Document Processor"
         }
-
-        response = requests.post(OPENROUTER_API_URL, headers=headers, json=payload, timeout=180) # 3 min timeout
+        response = requests.post(OPENROUTER_API_URL, headers=headers, json=payload, timeout=180)
         response.raise_for_status()
         result = response.json()
-
         if "choices" in result and result["choices"]:
             raw_content = result["choices"][0]["message"]["content"]
             return extract_json_from_text(raw_content)
         else:
             return {"error": "No 'choices' in API response from OpenRouter.", "details": result}
-
     except requests.exceptions.Timeout:
         return {"error": "API request timed out."}
     except requests.exceptions.RequestException as e:
@@ -142,44 +125,38 @@ def call_openrouter_ocr(image_filepath):
         return {"error": f"An unexpected error occurred during OCR: {str(e)}"}
 
 def extract_entities_from_ocr(ocr_json):
-    if not ocr_json or "extracted_fields" not in ocr_json or not isinstance(ocr_json
-
+    if not ocr_json or "extracted_fields" not in ocr_json or not isinstance(ocr_json.get("extracted_fields"), dict):
+        doc_type_from_ocr = "Unknown"
+        if isinstance(ocr_json, dict): # ocr_json itself might be an error dict
+            doc_type_from_ocr = ocr_json.get("document_type_detected", "Unknown (error in OCR)")
+        return {"name": None, "dob": None, "passport_no": None, "doc_type": doc_type_from_ocr}
 
     fields = ocr_json["extracted_fields"]
     doc_type = ocr_json.get("document_type_detected", "Unknown")
-
-    # Normalize potential field names (case-insensitive search)
     name_keys = ["full name", "name", "account holder name", "guest name"]
     dob_keys = ["date of birth", "dob"]
     passport_keys = ["document number", "passport number"]
-
     extracted_name = None
     for key in name_keys:
         for field_key, value in fields.items():
             if key == field_key.lower():
                 extracted_name = str(value) if value else None
                 break
-        if extracted_name:
-            break
-
+        if extracted_name: break
     extracted_dob = None
     for key in dob_keys:
         for field_key, value in fields.items():
             if key == field_key.lower():
                 extracted_dob = str(value) if value else None
                 break
-        if extracted_dob:
-            break
-
+        if extracted_dob: break
     extracted_passport_no = None
     for key in passport_keys:
         for field_key, value in fields.items():
             if key == field_key.lower():
                 extracted_passport_no = str(value).replace(" ", "").upper() if value else None
                 break
-        if extracted_passport_no:
-            break
-
+        if extracted_passport_no: break
     return {
         "name": extracted_name,
         "dob": extracted_dob,
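A quick sketch (with made-up field values) of the case-insensitive key matching this function performs over the OCR output:

fields = {"Full Name": "JANE DOE", "Date of Birth": "1990-01-01", "Document Number": "X 123 456"}
name_keys = ["full name", "name", "account holder name", "guest name"]

extracted_name = None
for key in name_keys:
    for field_key, value in fields.items():
        if key == field_key.lower():          # "Full Name".lower() == "full name"
            extracted_name = str(value) if value else None
            break
    if extracted_name: break
print(extracted_name)  # -> JANE DOE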
@@ -192,64 +169,42 @@ def normalize_name(name):
     return "".join(filter(str.isalnum, name)).lower()
 
 def get_person_id_and_update_profiles(doc_id, entities, current_persons_data):
-    """
-    Tries to assign a document to an existing person or creates a new one.
-    Returns a person_key.
-    Updates current_persons_data in place.
-    """
     passport_no = entities.get("passport_no")
     name = entities.get("name")
     dob = entities.get("dob")
-
-    # 1. Match by Passport Number (strongest identifier)
     if passport_no:
         for p_key, p_data in current_persons_data.items():
             if passport_no in p_data.get("passport_numbers", set()):
                 p_data["doc_ids"].add(doc_id)
-                # Update person profile with potentially new name/dob if current is missing
                 if name and not p_data.get("canonical_name"): p_data["canonical_name"] = name
                 if dob and not p_data.get("canonical_dob"): p_data["canonical_dob"] = dob
                 return p_key
-
-        new_person_key = f"person_{passport_no}" # Or more robust ID generation
+        new_person_key = f"person_{passport_no}"
         current_persons_data[new_person_key] = {
-            "canonical_name": name,
-            "canonical_dob": dob,
+            "canonical_name": name, "canonical_dob": dob,
             "names": {normalize_name(name)} if name else set(),
             "dobs": {dob} if dob else set(),
-            "passport_numbers": {passport_no},
-            "doc_ids": {doc_id},
+            "passport_numbers": {passport_no}, "doc_ids": {doc_id},
             "display_name": name or f"Person (ID: {passport_no})"
         }
         return new_person_key
-
-    # 2. Match by Normalized Name + DOB (if passport not found or not present)
     if name and dob:
         norm_name = normalize_name(name)
         composite_key_nd = f"{norm_name}_{dob}"
         for p_key, p_data in current_persons_data.items():
-            # Check if this name and dob combo has been seen for this person
             if norm_name in p_data.get("names", set()) and dob in p_data.get("dobs", set()):
                 p_data["doc_ids"].add(doc_id)
                 return p_key
-        # New person based on name and DOB
         new_person_key = f"person_{composite_key_nd}_{str(uuid.uuid4())[:4]}"
         current_persons_data[new_person_key] = {
-            "canonical_name": name,
-            "canonical_dob": dob,
-            "names": {norm_name},
-            "dobs": {dob},
-            "passport_numbers": set(),
-            "doc_ids": {doc_id},
+            "canonical_name": name, "canonical_dob": dob,
+            "names": {norm_name}, "dobs": {dob},
+            "passport_numbers": set(), "doc_ids": {doc_id},
             "display_name": name
         }
         return new_person_key
-
-    # 3. If only name, less reliable, create new person (could add fuzzy matching later)
     if name:
         norm_name = normalize_name(name)
-        # Check if a person with just this name exists and has no other strong identifiers yet
-        # This part can be made more robust, for now, it might create more splits
        new_person_key = f"person_{norm_name}_{str(uuid.uuid4())[:4]}"
        current_persons_data[new_person_key] = {
            "canonical_name": name, "canonical_dob": None,
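To illustrate the matching keys used above (names are hypothetical): normalize_name strips everything but alphanumerics and lowercases, so differently formatted spellings of the same name collapse to one key.

def normalize_name(name):
    return "".join(filter(str.isalnum, name)).lower()

print(normalize_name("Doe, Jane A."))  # -> doejanea
print(normalize_name("DOE JANE A"))    # -> doejanea  (same key, so name+DOB matching can link both documents)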
@@ -257,8 +212,6 @@ def get_person_id_and_update_profiles(doc_id, entities, current_persons_data):
             "doc_ids": {doc_id}, "display_name": name
         }
         return new_person_key
-
-    # 4. Unclassifiable for now, assign a generic unique person key
     generic_person_key = f"unidentified_person_{str(uuid.uuid4())[:6]}"
     current_persons_data[generic_person_key] = {
         "canonical_name": "Unknown", "canonical_dob": None,
@@ -267,17 +220,14 @@ def get_person_id_and_update_profiles(doc_id, entities, current_persons_data):
     }
     return generic_person_key
 
-
 def format_dataframe_data(current_files_data):
-    # Headers for the dataframe
-    # "ID", "Filename", "Status", "Detected Type", "Extracted Name", "Extracted DOB", "Main ID", "Person Key"
     df_rows = []
     for f_data in current_files_data:
-        entities = f_data.get("entities") or {}
+        entities = f_data.get("entities") or {} # CORRECTED LINE HERE
         df_rows.append([
-            f_data
-            f_data
-            f_data
+            f_data.get("doc_id", "N/A")[:8],
+            f_data.get("filename", "N/A"),
+            f_data.get("status", "N/A"),
             entities.get("doc_type", "N/A"),
             entities.get("name", "N/A"),
             entities.get("dob", "N/A"),
@@ -289,37 +239,33 @@ def format_dataframe_data(current_files_data):
 def format_persons_markdown(current_persons_data, current_files_data):
     if not current_persons_data:
         return "No persons identified yet."
-
     md_parts = ["## Classified Persons & Documents\n"]
     for p_key, p_data in current_persons_data.items():
         display_name = p_data.get('display_name', p_key)
         md_parts.append(f"### Person: {display_name} (Profile Key: {p_key})")
         if p_data.get("canonical_dob"): md_parts.append(f"* DOB: {p_data['canonical_dob']}")
         if p_data.get("passport_numbers"): md_parts.append(f"* Passport(s): {', '.join(p_data['passport_numbers'])}")
-
         md_parts.append("* Documents:")
         doc_ids_for_person = p_data.get("doc_ids", set())
         if doc_ids_for_person:
             for doc_id in doc_ids_for_person:
-                # Find the filename and detected type from current_files_data
                 doc_detail = next((f for f in current_files_data if f["doc_id"] == doc_id), None)
                 if doc_detail:
-                    filename = doc_detail
-
+                    filename = doc_detail.get("filename", "Unknown File")
+                    doc_entities = doc_detail.get("entities") or {}
+                    doc_type = doc_entities.get("doc_type", "Unknown Type")
                     md_parts.append(f" - {filename} (`{doc_type}`)")
                 else:
-                    md_parts.append(f" - Document ID: {doc_id[:8]} (details
+                    md_parts.append(f" - Document ID: {doc_id[:8]} (details error)")
         else:
             md_parts.append(" - No documents currently assigned.")
         md_parts.append("\n---\n")
     return "\n".join(md_parts)
 
-# --- Main Gradio Processing Function (Generator) ---
 def process_uploaded_files(files_list, progress=gr.Progress(track_tqdm=True)):
     global processed_files_data, person_profiles
     processed_files_data = []
     person_profiles = {}
-
     if not OPENROUTER_API_KEY:
         yield (
             [["N/A", "ERROR", "OpenRouter API Key not configured.", "N/A", "N/A", "N/A", "N/A", "N/A"]],
@@ -327,74 +273,62 @@ def process_uploaded_files(files_list, progress=gr.Progress(track_tqdm=True)):
             "{}", "API Key Missing. Processing halted."
         )
         return
-
     if not files_list:
         yield ([], "No files uploaded.", "{}", "Upload files to begin.")
         return
-
-    # Initialize processed_files_data
     for i, file_obj in enumerate(files_list):
         doc_uid = str(uuid.uuid4())
         processed_files_data.append({
             "doc_id": doc_uid,
-            "filename": os.path.basename(file_obj.name
-            "filepath": file_obj.name,
+            "filename": os.path.basename(file_obj.name if hasattr(file_obj, 'name') else f"file_{i+1}.unknown"),
+            "filepath": file_obj.name if hasattr(file_obj, 'name') else None, # file_obj itself is filepath if from gr.Files type="filepath"
             "status": "Queued",
             "ocr_json": None,
             "entities": None,
             "assigned_person_key": None
         })
-
     initial_df_data = format_dataframe_data(processed_files_data)
     initial_persons_md = format_persons_markdown(person_profiles, processed_files_data)
     yield (initial_df_data, initial_persons_md, "{}", f"Initialized. Found {len(files_list)} files.")
-
-    # Iterate and process each file
     for i, file_data_item in enumerate(progress.tqdm(processed_files_data, desc="Processing Documents")):
         current_doc_id = file_data_item["doc_id"]
         current_filename = file_data_item["filename"]
-
-
+        if not file_data_item["filepath"]: # Check if filepath is valid
+            file_data_item["status"] = "Error: Invalid file path"
+            df_data = format_dataframe_data(processed_files_data)
+            persons_md = format_persons_markdown(person_profiles, processed_files_data)
+            yield(df_data, persons_md, "{}", f"({i+1}/{len(processed_files_data)}) Error with file {current_filename}")
+            continue
+
         file_data_item["status"] = "OCR in Progress..."
         df_data = format_dataframe_data(processed_files_data)
         persons_md = format_persons_markdown(person_profiles, processed_files_data)
         yield (df_data, persons_md, "{}", f"({i+1}/{len(processed_files_data)}) OCR for: {current_filename}")
-
         ocr_result = call_openrouter_ocr(file_data_item["filepath"])
         file_data_item["ocr_json"] = ocr_result
-
         if "error" in ocr_result:
-            file_data_item["status"] = f"OCR Error: {ocr_result['error'][:50]}..."
+            file_data_item["status"] = f"OCR Error: {str(ocr_result['error'])[:50]}..."
             df_data = format_dataframe_data(processed_files_data)
             yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) OCR Error on {current_filename}")
             continue
-
         file_data_item["status"] = "OCR Done. Extracting Entities..."
         df_data = format_dataframe_data(processed_files_data)
         yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) OCR Done for {current_filename}")
-
-        # 2. Entity Extraction
         entities = extract_entities_from_ocr(ocr_result)
         file_data_item["entities"] = entities
         file_data_item["status"] = "Entities Extracted. Classifying..."
         df_data = format_dataframe_data(processed_files_data)
         yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) Entities for {current_filename}")
-
-        # 3. Person Classification / Linking
         person_key = get_person_id_and_update_profiles(current_doc_id, entities, person_profiles)
         file_data_item["assigned_person_key"] = person_key
         file_data_item["status"] = "Classified"
-
         df_data = format_dataframe_data(processed_files_data)
         persons_md = format_persons_markdown(person_profiles, processed_files_data)
         yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) Classified {current_filename} -> {person_key}")
-
     final_df_data = format_dataframe_data(processed_files_data)
     final_persons_md = format_persons_markdown(person_profiles, processed_files_data)
     yield (final_df_data, final_persons_md, "{}", f"All {len(processed_files_data)} documents processed.")
 
-
-# --- Gradio UI Layout ---
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# 📄 Intelligent Document Processor & Classifier")
     gr.Markdown(
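process_uploaded_files above is a generator: each yield carries one value per output component wired up in the click handler below, which is how Gradio streams progress into the UI while the loop runs. A minimal standalone sketch of the same pattern (component names here are hypothetical):

import time
import gradio as gr

def slow_count(n):
    for i in range(1, int(n) + 1):
        time.sleep(0.5)
        yield f"step {i}/{int(n)}"  # each yield refreshes the Textbox below

with gr.Blocks() as sketch:
    steps = gr.Number(value=3, label="Steps")
    status = gr.Textbox(label="Progress")
    gr.Button("Run").click(fn=slow_count, inputs=steps, outputs=status)

# sketch.launch()  # uncomment to try it locally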
@@ -402,58 +336,56 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         "The system will perform OCR, attempt to extract key entities, and classify documents by the person they belong to.**\n"
         "Ensure `OPENROUTER_API_KEY` is set as a Secret in your Hugging Face Space."
     )
-
     if not OPENROUTER_API_KEY:
         gr.Markdown("<h3 style='color:red;'>⚠️ ERROR: `OPENROUTER_API_KEY` is not set in Space Secrets! OCR will fail.</h3>")
-
     with gr.Row():
         with gr.Column(scale=1):
-            files_input = gr.Files(label="Upload Document Images (Bulk)", file_count="multiple", type="filepath")
-            process_button = gr.Button("Process Uploaded Documents", variant="primary")
+            files_input = gr.Files(label="Upload Document Images (Bulk)", file_count="multiple", type="filepath") # Using filepath
+            process_button = gr.Button("🚀 Process Uploaded Documents", variant="primary")
             overall_status_textbox = gr.Textbox(label="Overall Progress", interactive=False, lines=1)
-
     gr.Markdown("---")
     gr.Markdown("## Document Processing Details")
-    # "ID", "Filename", "Status", "Detected Type", "Extracted Name", "Extracted DOB", "Main ID", "Person Key"
     dataframe_headers = ["Doc ID (short)", "Filename", "Status", "Detected Type", "Name", "DOB", "Passport No.", "Assigned Person Key"]
     document_status_df = gr.Dataframe(
         headers=dataframe_headers,
         datatype=["str"] * len(dataframe_headers),
         label="Individual Document Status & Extracted Entities",
-        row_count=(
+        row_count=(1, "dynamic"), # Start with 1 row, dynamically grows
         col_count=(len(dataframe_headers), "fixed"),
         wrap=True
     )
-
     ocr_json_output = gr.Code(label="Selected Document OCR JSON", language="json", interactive=False)
-
     gr.Markdown("---")
     person_classification_output_md = gr.Markdown("## Classified Persons & Documents\nNo persons identified yet.")
-
-    # Event Handlers
     process_button.click(
         fn=process_uploaded_files,
         inputs=[files_input],
         outputs=[
             document_status_df,
             person_classification_output_md,
             ocr_json_output,
             overall_status_textbox
         ]
     )
-
     @document_status_df.select(inputs=None, outputs=ocr_json_output, show_progress="hidden")
     def display_selected_ocr(evt: gr.SelectData):
         if evt.index is None or evt.index[0] is None:
             return "{}"
-
         selected_row_index = evt.index[0]
-
+        # Ensure processed_files_data is accessible here. If it's truly global, it should be.
+        # For safety, one might pass it or make it part of a class if this were more complex.
+        if 0 <= selected_row_index < len(processed_files_data):
             selected_doc_data = processed_files_data[selected_row_index]
-            if selected_doc_data and selected_doc_data
-
-
-
+            if selected_doc_data and selected_doc_data.get("ocr_json"):
+                # Check if ocr_json is already a dict, if not, try to parse (though it should be)
+                ocr_data_to_display = selected_doc_data["ocr_json"]
+                if isinstance(ocr_data_to_display, str): # Should not happen if stored correctly
+                    try:
+                        ocr_data_to_display = json.loads(ocr_data_to_display)
+                    except json.JSONDecodeError:
+                        return json.dumps({"error": "Stored OCR data is not valid JSON string."}, indent=2)
+                return json.dumps(ocr_data_to_display, indent=2, ensure_ascii=False)
+        return json.dumps({ "message": "No OCR data found for selected row or selection out of bounds (check if processing is complete). Current rows: " + str(len(processed_files_data))}, indent=2)
 
 if __name__ == "__main__":
-    demo.queue().launch(debug=True,share=
+    demo.queue().launch(debug=True, share=os.environ.get("GRADIO_SHARE", "true").lower() == "true")
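The new launch line derives the share flag from the GRADIO_SHARE environment variable; the check it performs is equivalent to this small sketch (so, for example, setting GRADIO_SHARE=false runs the app without a public share link):

import os

share = os.environ.get("GRADIO_SHARE", "true").lower() == "true"
# unset or "true"/"TRUE" -> True (share link enabled); any other value -> False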