Spaces:
Sleeping
Sleeping
| from doctr.models import detection_predictor, recognition_predictor | |
| from doctr.io import DocumentFile | |
| from surya.recognition import RecognitionPredictor | |
| from surya.detection import DetectionPredictor | |
| from PIL import Image | |
| # from functools import lru_cache | |
| from torchvision import models | |
| from typing import List | |
| from fastapi import HTTPException | |
| from data_models import Citizenship | |
| import json | |
| import torchvision.transforms as transforms | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import cv2 | |
| import regex as re | |
| import requests | |
| # import os | |
| import pickle | |
# Character sets
# Output alphabet of the digit CRNNs: OCRModelManager.get_model maps class
# index i to CHARACTER_NUM[i].
CHARACTER_NUM = "0123456789-"
# Output alphabet of the Devanagari-letter CRNN; class index i maps to
# CHARACTER_LETTER[i]. Note that index 0 is a space character.
CHARACTER_LETTER = ''' "()-./0123456789:?ABCDEFGHIKLMNOPQRSTUWYabcdefghijklmnoprstuvwyँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॒॑॓॔क़ख़ग़ज़ड़ढ़फ़य़ॠॢ।॥०१२३४५६७८९॰ॱॲॻॼॽॾ^'''
# Model paths - these should be configurable
# CRNN checkpoint per model type, looked up by OCRModelManager.get_model.
# NOTE(review): no call site visible here passes 'classify_ne', and predict_ne
# loads 'models/dev_roman_classifier.pth' directly - confirm which path is live.
MODEL_PATHS = {
    'dev_digits': "models/devnagri_digits_20k_v2.pth",
    'roman_digits': "models/roman_digits_20k_v5.pth",
    'dev_letter': "models/small_devnagari_letter.pth",
    'classify_ne': "models/nepali_english_classifier.pth"
}
# Use GPU if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ResNetClassifier(nn.Module):
    """ResNet-50 image classifier with a frozen backbone and a small trainable head.

    The backbone is created with torchvision's ``IMAGENET1K_V2`` weights and all
    of its parameters are frozen. Its final ``fc`` layer is then replaced by
    ``Linear(in_features, 128) -> ReLU -> Linear(128, num_classes)``, whose
    parameters remain trainable.

    Args:
        num_classes: number of logits produced per image.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        backbone = models.resnet50(weights='IMAGENET1K_V2')
        # Freeze the pretrained weights before swapping the head in, so that only
        # the newly created head layers keep requires_grad=True.
        backbone.requires_grad_(False)
        hidden_units = 128
        backbone.fc = nn.Sequential(
            nn.Linear(backbone.fc.in_features, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_classes),
        )
        self.base_model = backbone

    def forward(self, x):
        """Return the class logits for the image batch *x*."""
        return self.base_model(x)
| # Define the CRNN model | |
class CRNN(nn.Module):
    """Convolutional-recurrent text recogniser that emits per-column class scores.

    Four ``Conv3x3 -> BatchNorm -> ReLU -> MaxPool2x2`` stages shrink a
    ``(C, H, W)`` input to ``(512, H // 16, W // 16)``. Each remaining column
    is one time step for a single-layer bidirectional LSTM (128 hidden units per
    direction). A linear layer then maps every step to ``num_classes`` logits.

    Args:
        num_classes: size of the output alphabet (logits per time step).
        input_size: ``(channels, height, width)``. Only ``channels`` and
            ``height`` shape the layers. Inputs must have this height, while
            the width may vary and sets the sequence length ``W // 16``.
    """

    def __init__(self, num_classes, input_size=(1, 64, 256)):
        super().__init__()
        # Built in a loop, but the module order is unchanged, so the state_dict
        # keys (conv_block.0 ... conv_block.15) still match existing checkpoints.
        layers = []
        in_channels = input_size[0]
        for out_channels in (64, 128, 256, 512):
            layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # halves H and W
            ]
            in_channels = out_channels
        self.conv_block = nn.Sequential(*layers)

        # After four 2x poolings, a 64-pixel-high input leaves 64 // 16 = 4 rows.
        feature_height = input_size[1] // 16
        self.rnn = nn.LSTM(
            input_size=512 * feature_height,  # 512 * 4 = 2048 for the default size
            hidden_size=128,
            num_layers=1,
            bidirectional=True,
            # LSTM dropout only acts *between* stacked layers. With one layer the
            # old value (0.3) was a no-op that made PyTorch warn on every build.
            dropout=0.0,
            batch_first=True,
        )
        self.fc = nn.Linear(2 * 128, num_classes)  # forward + backward hidden states

    def forward(self, x):
        """Map ``(B, C, H, W)`` images to ``(B, W // 16, num_classes)`` logits."""
        features = self.conv_block(x)  # (B, 512, H // 16, W // 16)
        b, c, h, w = features.size()
        # One time step per feature column: (B, W', C, H') -> (B, W', C * H')
        sequence = features.permute(0, 3, 1, 2).contiguous().view(b, w, c * h)
        sequence, _ = self.rnn(sequence)  # (B, W', 256)
        return self.fc(sequence)  # (B, W', num_classes)
class OCRModelManager:
    """Process-wide singleton that loads and caches the OCR models.

    Every ``OCRModelManager()`` call returns the same instance. The doctr
    recognition model is loaded when the singleton is first created. Each CRNN
    checkpoint is loaded the first time its model type is requested and reused
    afterwards.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            # Initialise a local instance fully and publish it only once every
            # step has succeeded. Assigning cls._instance first meant a failure
            # here (e.g. the doctr weights failing to download) left a
            # half-initialised singleton that every later call silently reused.
            instance = super().__new__(cls)
            instance.models = {}
            instance.char_maps = {}
            instance.transforms = {}
            instance.initialize_transforms()
            # doctr recognition model used by predict_roman_letter.
            instance.roman_letter_model = recognition_predictor(pretrained=True)
            cls._instance = instance
        return cls._instance

    def initialize_transforms(self):
        """Register the shared 'standard' transform.

        The transform resizes to 64x256, converts to a tensor and normalises to [-1, 1].
        """
        self.transforms['standard'] = transforms.Compose([
            transforms.Resize((64, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    def get_model(self, model_type, character_set):
        """Return ``(model, char_map)`` for *model_type*, loading the CRNN on first use.

        Args:
            model_type: key into MODEL_PATHS.
            character_set: output alphabet, where class index ``i`` maps to
                ``character_set[i]``. It is only read the first time a given
                *model_type* is loaded; later calls reuse the cached mapping.

        Raises:
            ValueError: if *model_type* has no entry in MODEL_PATHS.
        """
        if model_type not in self.models:
            if model_type not in MODEL_PATHS:
                raise ValueError(f"Unknown model type: {model_type}")
            char_map = {
                'id_to_char': dict(enumerate(character_set)),
                'char_to_id': {c: i for i, c in enumerate(character_set)},
            }
            model = CRNN(num_classes=len(character_set))
            model.load_state_dict(torch.load(MODEL_PATHS[model_type], map_location=DEVICE))
            model.eval()  # inference only: freeze BatchNorm statistics
            # Cache both entries only after the checkpoint loaded successfully.
            self.char_maps[model_type] = char_map
            self.models[model_type] = model.to(DEVICE)
        return self.models[model_type], self.char_maps[model_type]

    def preprocess_image(self, image_path, model_type):
        """Load *image_path* as a normalised ``(1, 1, 64, 256)`` tensor on DEVICE.

        Digit models get a hard black/white threshold at 128. The 'dev_letter'
        model instead gets inverted intensities.
        """
        # The context manager closes the file handle. convert() returns an
        # independent in-memory copy, so `image` stays valid after closing.
        with Image.open(image_path) as source:
            image = source.convert('L')
        if model_type != 'dev_letter':
            # Binarize the image for digit models
            image = image.point(lambda x: 0 if x < 128 else 255, 'L')
        image = image.resize((256, 64))  # PIL sizes are (width, height)
        if model_type == 'dev_letter':
            image = Image.eval(image, lambda x: 255 - x)
        return self.transforms['standard'](image).unsqueeze(0).to(DEVICE)

    def predict(self, image_path, model_type, character_set):
        """Run the CRNN for *model_type* on *image_path* and decode greedily.

        The arg-max class of every time step is mapped through the character
        set, and the characters are concatenated as-is, with no blank removal and
        no merging of repeats.
        """
        # NOTE(review): if these checkpoints were trained with CTC loss, a
        # blank-removing / repeat-collapsing decode would be expected here -
        # confirm against the training code.
        model, char_map = self.get_model(model_type, character_set)
        tensor_image = self.preprocess_image(image_path, model_type)
        with torch.no_grad():
            output = model(tensor_image)  # (batch, seq_len, num_classes)
        best_ids = output.max(dim=2).indices[0].cpu().tolist()
        id_to_char = char_map['id_to_char']
        return ''.join(id_to_char[i] for i in best_ids)

    def predict_roman_letter(self, image_path):
        """Recognise Roman-script text in *image_path* with the doctr model.

        Returns ``result[0][0]`` of the doctr output. That is presumably the
        text of the first ``(value, confidence)`` prediction; confirm against
        the doctr version in use.
        """
        pages = DocumentFile.from_images(image_path)
        result = self.roman_letter_model(pages)
        return result[0][0]
# Initialize the model manager as a singleton
# Instantiating it here loads the doctr recognition model at import time; the
# CRNN checkpoints are loaded lazily by get_model on first use of each type.
ocr_manager = OCRModelManager()
# Simplified API functions
def dev_number(image_path):
    """Read a number from a Devanagari-digit image with the 'dev_digits' CRNN.

    The result is built from the characters of CHARACTER_NUM ("0"-"9" and "-").
    """
    return ocr_manager.predict(
        image_path,
        model_type='dev_digits',
        character_set=CHARACTER_NUM,
    )
def roman_number(image_path):
    """Read a number from a Roman-digit image with the 'roman_digits' CRNN.

    The result is built from the characters of CHARACTER_NUM ("0"-"9" and "-").
    """
    return ocr_manager.predict(
        image_path,
        model_type='roman_digits',
        character_set=CHARACTER_NUM,
    )
def dev_letter(image_path):
    """Read Devanagari text from an image with the 'dev_letter' CRNN.

    The result is built from the characters of CHARACTER_LETTER.
    """
    return ocr_manager.predict(
        image_path,
        model_type='dev_letter',
        character_set=CHARACTER_LETTER,
    )
def roman_letter(image_path):
    """Read Roman-script text from an image with the shared doctr recogniser."""
    return ocr_manager.predict_roman_letter(image_path=image_path)
# (model, label_encoder) pairs keyed by device string. The ResNet-50 classifier
# and its label encoder are loaded once per device instead of on every call.
_NE_CLASSIFIER_CACHE = {}

# Preprocessing for the classifier: shorter side to 256, centre 224x224 crop,
# then ImageNet mean/std normalisation.
_NE_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def _load_ne_classifier(device):
    """Return the cached ``(model, label_encoder)`` for *device*, loading it on first use."""
    key = str(device)
    cached = _NE_CLASSIFIER_CACHE.get(key)
    if cached is None:
        model = ResNetClassifier(num_classes=4).to(device)
        model.load_state_dict(torch.load('models/dev_roman_classifier.pth', map_location=device))
        model.eval()
        # NOTE: pickle.load can execute arbitrary code. This is acceptable only
        # because the encoder is a bundled artifact; never point it at user input.
        with open('models/dev_roman_label_encoder.pkl', 'rb') as f:
            label_encoder = pickle.load(f)
        cached = (model, label_encoder)
        _NE_CLASSIFIER_CACHE[key] = cached
    return cached


def predict_ne(image_path, device=None):
    """Classify the image at *image_path* with the 4-class ResNet-50 classifier.

    Args:
        image_path: path of the image to classify.
        device: torch device (or its name) to run on. ``None`` picks CUDA when
            available, otherwise CPU, which is the previous default behaviour.
            Explicit values were previously ignored and are now honoured.

    Returns:
        The label from the pickled label encoder for the arg-max class.
    """
    device = DEVICE if device is None else torch.device(device)
    model, label_encoder = _load_ne_classifier(device)
    # Close the file handle promptly; convert() yields an in-memory copy.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    image_tensor = _NE_TRANSFORM(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
    _, predicted = torch.max(output, 1)
    return label_encoder.inverse_transform([predicted.item()])[0]
doctr_detector = None
surya_recognition_predictor = None
surya_detection_predictor = None


def initialize_detector():
    """Return the shared detection/recognition models, creating any still missing.

    The models are cached in the three module-level globals defined just
    above. Each one is built the first time it is seen as ``None`` (doctr
    detector first, then surya recognition, then surya detection) and reused
    on later calls.

    Returns:
        tuple: ``(doctr_detector, surya_recognition_predictor, surya_detection_predictor)``.
    """
    global doctr_detector, surya_recognition_predictor, surya_detection_predictor

    doctr_detector = doctr_detector if doctr_detector is not None else detection_predictor(
        'db_mobilenet_v3_large',
        pretrained=True,
        assume_straight_pages=True,
        preserve_aspect_ratio=True,
    )
    surya_recognition_predictor = (
        surya_recognition_predictor
        if surya_recognition_predictor is not None
        else RecognitionPredictor()
    )
    surya_detection_predictor = (
        surya_detection_predictor
        if surya_detection_predictor is not None
        else DetectionPredictor()
    )
    return doctr_detector, surya_recognition_predictor, surya_detection_predictor
def get_cleaned_boxes(out, page, right_frac=0.7, top_frac=0.3, min_area=100):
    """Convert word detections to integer pixel boxes and drop unwanted ones.

    Args:
        out: detector output. ``out[0]['words']`` holds one row per word whose
            first four values are relative ``(x1, y1, x2, y2)`` corners. They
            are scaled here by page width/height, and any trailing values are
            ignored.
        page: the image the detections refer to, shaped ``(H, W)`` or ``(H, W, C)``.
        right_frac: boxes whose left edge lies right of ``right_frac * W`` ...
        top_frac: ... and whose top edge lies above ``top_frac * H`` (the
            top-right region of the page) are discarded.
        min_area: boxes with a pixel area below this are discarded.

    Returns:
        List of ``np.ndarray`` ``[x1, y1, x2, y2]`` integer boxes (truncated),
        in detection order.
    """
    h, w = page.shape[:2]
    scale = np.array([w, h, w, h])
    # Loop-invariant limits, hoisted out of the per-box loop.
    x_limit = right_frac * w
    y_limit = top_frac * h

    cleaned_boxes = []
    for box in out[0]['words']:
        coords = np.array(box[:4])  # two corners: (x1, y1) top-left, (x2, y2) bottom-right
        coords *= scale
        x1, y1, x2, y2 = coords
        if x1 > x_limit and y1 < y_limit:
            continue
        if (x2 - x1) * (y2 - y1) < min_area:
            continue
        cleaned_boxes.append(coords.astype('int'))
    return cleaned_boxes
def merge_boxes_same_line(boxes, y_thresh=5, x_thresh=60, row_thresh=15):
    """Snap word boxes into text rows and merge horizontally adjacent ones.

    The function works in two passes:

    1. Boxes are sorted top-to-bottom, then left-to-right, and grouped into
       rows. A box joins the current row while its top edge is within
       ``row_thresh`` pixels of the top edge of the row's first box. Every box
       in a row then gets the row's mean ``y1``/``y2`` (truncated to int), so
       boxes on one line share identical vertical bounds.
    2. The aligned boxes are re-sorted and scanned in order. A box is merged
       into the running box when their top edges differ by less than
       ``y_thresh`` and ``abs(box.x1 - running.x2) < x_thresh``.

    Args:
        boxes: sequence of ``(x1, y1, x2, y2)`` pixel boxes.
        y_thresh: max top-edge difference for two boxes to merge.
        x_thresh: max ``|x1 - previous x2|`` for two boxes to merge.
        row_thresh: max top-edge offset from a row's first box to join that row.

    Returns:
        ``np.ndarray`` of shape ``(n, 4)`` holding the merged boxes sorted by
        ``(y1, x1)``. The shape is ``(0, 4)`` when *boxes* is empty; the
        previous version raised IndexError in that case.
    """
    if len(boxes) == 0:
        return np.empty((0, 4), dtype=int)

    def _align_row(row):
        """Give every box in *row* the row's mean top and bottom edge."""
        avg_y1 = int(np.mean([b[1] for b in row]))
        avg_y2 = int(np.mean([b[3] for b in row]))
        return [(b[0], avg_y1, b[2], avg_y2) for b in row]

    # Pass 1: sort by y, then x, and group into rows.
    ordered = sorted(boxes, key=lambda b: (b[1], b[0]))
    aligned = []
    row = []
    row_y = ordered[0][1]
    for box in ordered:
        _, y1, _, _ = box
        if abs(y1 - row_y) <= row_thresh:
            row.append(box)
        else:
            aligned.extend(_align_row(row))
            row = [box]
            row_y = y1
    if row:
        aligned.extend(_align_row(row))

    # Pass 2: re-sort the aligned boxes and merge neighbours on the same line.
    aligned.sort(key=lambda b: (b[1], b[0]))
    merged = []
    p_x1, p_y1, p_x2, p_y2 = aligned[0]
    for x1, y1, x2, y2 in aligned[1:]:
        if abs(p_y1 - y1) < y_thresh and abs(x1 - p_x2) < x_thresh:
            p_x1 = min(p_x1, x1)
            p_y1 = min(p_y1, y1)
            p_x2 = max(p_x2, x2)
            p_y2 = max(p_y2, y2)
        else:
            merged.append([p_x1, p_y1, p_x2, p_y2])
            p_x1, p_y1, p_x2, p_y2 = x1, y1, x2, y2
    merged.append([p_x1, p_y1, p_x2, p_y2])
    return np.array(merged)
# Fuzzy match (up to 6 edits) for the card title "नेपाली नागरिकताको प्रमाणपत्र"
# ("Nepali citizenship certificate"). Text boxes up to the title are skipped.
_CITIZENSHIP_TITLE_PATTERN = re.compile(r'(नेपाली\s*नागरिकताको\s*प्रमाणपत्र){e<=6}')


def _ocr_crop(pil_image, recognizer, detector):
    """Run surya OCR (English + Nepali) on one crop and join its text lines with spaces."""
    predictions = recognizer(images=[pil_image], langs=[["en", 'ne']], det_predictor=detector)
    return " ".join(line.text.strip() for line in predictions[0].text_lines).strip()


def ocr_citizenship(image_path: str) -> List[List[str]]:
    """OCR a citizenship-card image into text grouped by visual line.

    The page is brightened (x1.5) and resized to 720x480. Word boxes are then
    detected with doctr and merged per line. Each merged box after the title is
    read with surya.

    Args:
        image_path: path of the card image.

    Returns:
        One list per visual line (a new line starts when a box's top edge is
        more than 5 px below the previous box's). Each inner list holds the
        text of that line's boxes in order.

    Raises:
        FileNotFoundError: if OpenCV cannot read *image_path*.
    """
    doctr_detector, surya_recognition_predictor, surya_detection_predictor = initialize_detector()

    page = cv2.imread(image_path)
    if page is None:  # imread signals failure by returning None, not by raising
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # cv2 decodes to BGR, but doctr and PIL's grayscale conversion expect RGB.
    page = cv2.cvtColor(page, cv2.COLOR_BGR2RGB)
    page = cv2.convertScaleAbs(page, alpha=1.5, beta=0)
    page = cv2.resize(page, (720, 480))

    out = doctr_detector([page])
    merged = merge_boxes_same_line(get_cleaned_boxes(out, page))

    y_thresh = 5
    started = False
    prev_y = None
    full_result = []
    line_result = []
    # NOTE(review): the first three merged boxes are skipped unconditionally,
    # presumably header text above the title; confirm.
    for x1, y1, x2, y2 in merged[3:]:
        crop = page[y1:y2, x1:x2]
        pil_image = Image.fromarray(crop).convert('L')
        text = _ocr_crop(pil_image, surya_recognition_predictor, surya_detection_predictor)

        if not started:
            # NOTE(review): the source's indentation was lost. This assumes every
            # box up to and including the title is skipped; confirm.
            if _CITIZENSHIP_TITLE_PATTERN.search(text):
                started = True
            continue

        if prev_y is None:
            prev_y = y1
        if y1 - prev_y > y_thresh:
            full_result.append(line_result)
            line_result = []
        line_result.append(text)
        prev_y = y1

    # Flush the final line; it was previously dropped.
    if line_result:
        full_result.append(line_result)
    return full_result
# System prompt placed before the OCR text in perform_citizenship_ocr. The
# target JSON shape is constrained separately by the request's "format" field
# in create_local_model.
PARSE_PROMPT = "You are a parsing agent. Your task is to generate a json response from the given text corpus."
def create_local_model(message, base_model, timeout=300):
    """Send a chat request to the hosted Ollama proxy and return its parsed JSON reply.

    Despite its name, this does not create a model. It POSTs *message* to the
    proxy's ``api/chat`` endpoint, constrains the reply to the JSON schema of
    *base_model*, and decodes the assistant message content as JSON.

    Args:
        message: chat history as a list of ``{"role": ..., "content": ...}`` dicts.
        base_model: class providing ``model_json_schema()``, presumably a
            Pydantic v2 model. Its schema is sent as the ``format`` constraint.
        timeout: seconds to wait for the proxy. Without a timeout, ``requests``
            would block forever on a stalled connection.

    Returns:
        The JSON-decoded ``message.content`` of the proxy response.

    Raises:
        HTTPException: carrying the upstream status code on a non-200 reply, or
            status 500 for any other failure (network error, timeout, bad JSON).
    """
    import logging  # function-scope import keeps this block self-contained

    logger = logging.getLogger(__name__)
    try:
        ollama_endpoint = "api/chat"
        url = f"https://aioverlords-amnil-internal-ollama.hf.space/proxy/{ollama_endpoint}"
        data = {
            "data": {
                "model": "aisingapore/Llama-SEA-LION-v3-8B-IT",
                "messages": message,
                "stream": False,
                "format": base_model.model_json_schema(),
            }
        }
        response = requests.post(url, json=data, timeout=timeout)
        if response.status_code != 200:
            logger.error("Request Error: %s %s", response.status_code, response.text)
            raise HTTPException(status_code=response.status_code, detail=response.text)
        payload = response.json()
        # The reply contains the parsed document (personal data). Log it only at
        # debug level instead of printing it to stdout on every request.
        logger.debug("Request Success: %s", payload)
        return json.loads(payload["message"]["content"])
    except HTTPException:
        raise  # bare raise keeps the original traceback
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def perform_citizenship_ocr(image_path):
    """OCR a citizenship card and have the LLM parse the text into ``Citizenship`` JSON.

    Args:
        image_path: path of the card image.

    Returns:
        The JSON object returned by create_local_model for the Citizenship schema.

    Raises:
        HTTPException: any HTTPException raised downstream is propagated
            unchanged, so the LLM proxy's status code is kept. Every other
            error becomes a 500.
    """
    try:
        unparsed_result = ocr_citizenship(image_path)
        message = [
            {"role": "system", "content": PARSE_PROMPT},
            {"role": "user", "content": f"Given Text: \n{unparsed_result}"},
        ]
        return create_local_model(message, Citizenship)
    except HTTPException:
        # Previously re-wrapped as a 500, discarding e.g. the upstream 503/404.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e