Spaces:
Sleeping
Sleeping
| # # ! pip uninstall -y tensorflow | |
| # # ! pip install "python-doctr[torch,viz]" | |
| # from fastapi import FastAPI, UploadFile, File | |
| # from fastapi.responses import JSONResponse | |
| # from utils import dev_number, roman_number, dev_letter, roman_letter | |
| # import tempfile | |
| # app = FastAPI() | |
| # @app.post("/ocr_dev_number/") | |
| # async def extract_dev_number(image: UploadFile = File(...)): | |
| # # Save uploaded image temporarily | |
| # with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp: | |
| # content = await image.read() | |
| # tmp.write(content) | |
| # tmp_path = tmp.name | |
| # # predict the image | |
| # predicted_str = dev_number(tmp_path) | |
| # # Return result as JSON | |
| # return JSONResponse(content={"predicted_str": predicted_str}) | |
| # @app.post("/ocr_roman_number/") | |
| # async def extract_roman_number(image: UploadFile = File(...)): | |
| # # Save uploaded image temporarily | |
| # with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp: | |
| # content = await image.read() | |
| # tmp.write(content) | |
| # tmp_path = tmp.name | |
| # # predict the image | |
| # predicted_str = roman_number(tmp_path) | |
| # # Return result as JSON | |
| # return JSONResponse(content={"predicted_str": predicted_str}) | |
| # @app.post("/ocr_dev_letter/") | |
| # async def extract_dev_letter(image: UploadFile = File(...)): | |
| # # Save uploaded image temporarily | |
| # with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp: | |
| # content = await image.read() | |
| # tmp.write(content) | |
| # tmp_path = tmp.name | |
| # # predict the image | |
| # predicted_str = dev_letter(tmp_path) | |
| # # Return result as JSON | |
| # return JSONResponse(content={"predicted_str": predicted_str}) | |
| # @app.post("/ocr_roman_letter/") | |
| # async def extract_roman_letter(image: UploadFile = File(...)): | |
| # # Save uploaded image temporarily | |
| # with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp: | |
| # content = await image.read() | |
| # tmp.write(content) | |
| # tmp_path = tmp.name | |
| # # predict the image | |
| # predicted_str = roman_letter(tmp_path) | |
| # # Return result as JSON | |
| # return JSONResponse(content={"predicted_str": predicted_str}) | |
import os
import shutil
import tempfile
from typing import Literal, Optional

from fastapi import FastAPI, UploadFile, File, HTTPException, Depends
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# Import from optimized utils
from utils import dev_number, roman_number, dev_letter, roman_letter, predict_ne, perform_citizenship_ocr
# Application object shared by every endpoint in this module; the keyword
# arguments are the API's title, description and version metadata.
app = FastAPI(
    title="OCR API",
    description="API for optical character recognition of Roman and Devanagari text",
    version="1.0.0",
)
class OCRResponse(BaseModel):
    """Response model for OCR endpoints"""

    # Text recognized in the submitted image.
    predicted_str: str
    # Optional confidence field. Annotated as Optional[float] so the None
    # default is a valid value of the declared type (a bare `float` annotation
    # with a None default rejects an explicitly supplied None under strict
    # validation).
    confidence: Optional[float] = None
# Helper function to handle file uploads consistently
async def save_upload_file_tmp(upload_file: UploadFile) -> str:
    """Save an upload file to a temporary file and return the path.

    The temporary file is created with ``delete=False``, so it outlives this
    call; the caller is responsible for removing it once it is done.

    Args:
        upload_file: The uploaded file. Its filename extension (if any) is
            reused as the temporary file's suffix; ".png" is used when the
            upload carries no filename.

    Returns:
        Filesystem path of the temporary file holding the upload's content.

    Raises:
        HTTPException: With status 500 if reading the upload or writing the
            temporary file fails.
    """
    tmp_path = None
    try:
        # Create a temporary file with the appropriate suffix
        suffix = os.path.splitext(upload_file.filename)[1] if upload_file.filename else ".png"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the path first so a failed read/write can still be cleaned up.
            tmp_path = tmp.name
            content = await upload_file.read()
            tmp.write(content)
        return tmp_path
    except Exception as e:
        # The file was created with delete=False, so a failure above would
        # otherwise leave a partial file behind. Removal is best-effort: the
        # original error is the one worth reporting.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}") from e
def _join_recognized_text(result) -> str:
    """Join the text of a page/block structured OCR result with single spaces."""
    pieces = []
    for page in result.pages:
        for block in page.blocks:
            if hasattr(block, "value"):
                # Block-level text, as the original implementation expected.
                pieces.append(block.value)
            else:
                # NOTE(review): presumably a docTR document, where a block
                # holds lines and each line holds words carrying `.value`;
                # confirm against what the utils OCR functions return.
                pieces.extend(word.value for line in block.lines for word in line.words)
    return " ".join(pieces)


# Generic OCR function that can be reused across endpoints
async def process_ocr_request(
    image: UploadFile = File(...),
    ocr_function=None
):
    """Process an OCR request using the specified OCR function.

    The upload is written to a temporary file, ``ocr_function`` is called with
    that file's path, and the temporary file is removed whether or not the
    call succeeds.

    Args:
        image: Uploaded image to run recognition on.
        ocr_function: Callable taking the image path. It may return a plain
            string or an object exposing ``pages`` -> ``blocks``.

    Returns:
        JSONResponse whose body is ``{"predicted_str": <text>}``.

    Raises:
        HTTPException: With status 500 when no OCR function is given, when the
            upload cannot be saved, or when recognition fails.
    """
    if not ocr_function:
        raise HTTPException(status_code=500, detail="OCR function not specified")

    tmp_path = None
    try:
        # Save uploaded image temporarily
        tmp_path = await save_upload_file_tmp(image)
        # Process the image with the specified OCR function
        result = ocr_function(tmp_path)
        # Handle different types of results (string vs structured output)
        predicted = result if isinstance(result, str) else _join_recognized_text(result)
        return JSONResponse(content={"predicted_str": predicted})
    except HTTPException:
        # Already a well-formed HTTP error (e.g. from the upload helper);
        # re-wrapping it would only nest its message inside another one.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"OCR processing error: {str(e)}") from e
    finally:
        # Clean up the temporary file on success and on every failure path.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
# Endpoints with minimal duplication
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def extract_text(
    image: UploadFile = File(...),
    model_type: Literal["dev_number", "roman_number", "dev_letter", "roman_letter"] = "roman_letter"
):
    """
    Generic OCR endpoint that can handle any supported recognition type.
    - **image**: Image file to process
    - **model_type**: Type of OCR to perform
    """
    # Dispatch table: each accepted model_type maps to its recognizer from utils.
    recognizers = {
        "dev_number": dev_number,
        "roman_number": roman_number,
        "dev_letter": dev_letter,
        "roman_letter": roman_letter,
    }
    recognizer = recognizers.get(model_type)
    if recognizer is None:
        # Unknown model type: reject as a client error.
        raise HTTPException(status_code=400, detail=f"Invalid model type: {model_type}")
    return await process_ocr_request(image, recognizer)
# For backward compatibility, keep the original endpoints
# The route path is the one the previous (commented-out) implementation of
# this handler used; without the decorator the handler is never registered.
@app.post("/ocr_dev_number/")
async def extract_dev_number(image: UploadFile = File(...)):
    """Extract Devanagari numbers from an image"""
    # Thin wrapper: delegates to the shared OCR pipeline with dev_number.
    return await process_ocr_request(image, dev_number)
# The route path is the one the previous (commented-out) implementation of
# this handler used; without the decorator the handler is never registered.
@app.post("/ocr_roman_number/")
async def extract_roman_number(image: UploadFile = File(...)):
    """Extract Roman numbers from an image"""
    # Thin wrapper: delegates to the shared OCR pipeline with roman_number.
    return await process_ocr_request(image, roman_number)
# The route path is the one the previous (commented-out) implementation of
# this handler used; without the decorator the handler is never registered.
@app.post("/ocr_dev_letter/")
async def extract_dev_letter(image: UploadFile = File(...)):
    """Extract Devanagari letters from an image"""
    # Thin wrapper: delegates to the shared OCR pipeline with dev_letter.
    return await process_ocr_request(image, dev_letter)
# The route path is the one the previous (commented-out) implementation of
# this handler used; without the decorator the handler is never registered.
@app.post("/ocr_roman_letter/")
async def extract_roman_letter(image: UploadFile = File(...)):
    """Extract Roman letters from an image"""
    # Thin wrapper: delegates to the shared OCR pipeline with roman_letter.
    return await process_ocr_request(image, roman_letter)
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def classify_ne(image: UploadFile = File(...)):
    """Predict Named Entities from an image"""
    # Placeholder for Named Entity Recognition logic
    image_path = await save_upload_file_tmp(image)
    try:
        prediction = predict_ne(
            image_path=image_path,
            # model="models/nepali_english_classifier.pth",  # Update with actual model path
            device="cpu",  # device string handed to predict_ne
        )
    finally:
        # The upload helper creates the file with delete=False, so it must be
        # removed here or every request leaves a file behind.
        if os.path.exists(image_path):
            os.unlink(image_path)
    # Implement the logic as per your requirements
    return JSONResponse(content={"predicted": prediction})
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def ocr_citizenship(image: UploadFile = File(...)):
    """OCR the provided Nepali Citizenship card"""
    image_path = await save_upload_file_tmp(image)
    try:
        prediction = perform_citizenship_ocr(
            image_path=image_path,
        )
    finally:
        # The upload helper creates the file with delete=False, so it must be
        # removed here or every request leaves a file behind.
        if os.path.exists(image_path):
            os.unlink(image_path)
    # The OCR result is returned as the JSON body as-is.
    return JSONResponse(content=prediction)
# Health check endpoint
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def health_check():
    """Health check endpoint to verify the API is running"""
    # Constant payload; takes no input and touches no other state.
    status_payload = {"status": "healthy"}
    return status_payload