import streamlit as st
import os
import numpy as np
import cv2
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from ultralytics import YOLO
import Levenshtein
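# These imports correspond to the pip packages streamlit, numpy, opencv-python
# (or opencv-python-headless), Pillow, torch, transformers, ultralytics, and
# Levenshtein, which are assumed to be pinned in the Space's requirements.txt.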
# Page config
st.set_page_config(
    page_title="Thai License Plate Detection",
    page_icon="🚗",
    layout="centered"
)

# Initialize session state for models
if 'models_loaded' not in st.session_state:
    st.session_state['models_loaded'] = False
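# NOTE: model objects are cached manually in st.session_state further below;
# decorating the loader with @st.cache_resource would be an idiomatic
# Streamlit alternative to this manual flag.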
def load_ocr_models():
    """Load OCR models with proper error handling"""
    try:
        # Set environment variables to suppress warnings
        os.environ['TOKENIZERS_PARALLELISM'] = 'false'
        # Load processor with specific config
        processor = TrOCRProcessor.from_pretrained(
            'openthaigpt/thai-trocr',
            revision='main',
            use_auth_token=False,
            trust_remote_code=True,
            local_files_only=False
        )
        # Load OCR model with specific config
        ocr_model = VisionEncoderDecoderModel.from_pretrained(
            'openthaigpt/thai-trocr',
            revision='main',
            use_auth_token=False,
            trust_remote_code=True,
            local_files_only=False
        )
        # Move model to CPU explicitly
        ocr_model = ocr_model.to('cpu')
        return processor, ocr_model
    except Exception as e:
        st.error(f"Error loading OCR models: {str(e)}")
        st.error("Detailed error information:")
        import traceback
        st.code(traceback.format_exc())
        return None, None
# Load models
def load_models():
    try:
        # Check if YOLO weights exist
        if not os.path.exists('best.pt'):
            st.error("YOLO model weights (best.pt) not found in the current directory!")
            return None, None, None
        # Load YOLO model
        try:
            yolo_model = YOLO('best.pt', task='detect')
        except Exception as yolo_error:
            st.error(f"Error loading YOLO model: {str(yolo_error)}")
            return None, None, None
        # Load OCR models
        processor, ocr_model = load_ocr_models()
        if processor is None or ocr_model is None:
            return None, None, None
        return processor, ocr_model, yolo_model
    except Exception as e:
        st.error(f"Error in model loading: {str(e)}")
        st.error("Detailed error information:")
        import traceback
        st.code(traceback.format_exc())
        return None, None, None
# Thai provinces list (77 entries, Bangkok included)
thai_provinces = [
    "กรุงเทพมหานคร", "กระบี่", "กาญจนบุรี", "กาฬสินธุ์", "กำแพงเพชร", "ขอนแก่น", "จันทบุรี", "ฉะเชิงเทรา",
    "ชลบุรี", "ชัยนาท", "ชัยภูมิ", "ชุมพร", "เชียงราย", "เชียงใหม่", "ตรัง", "ตราด", "ตาก", "นครนายก",
    "นครปฐม", "นครพนม", "นครราชสีมา", "นครศรีธรรมราช", "นครสวรรค์", "นนทบุรี", "นราธิวาส", "น่าน", "บึงกาฬ",
    "บุรีรัมย์", "ปทุมธานี", "ประจวบคีรีขันธ์", "ปราจีนบุรี", "ปัตตานี", "พระนครศรีอยุธยา", "พะเยา", "พังงา", "พัทลุง",
    "พิจิตร", "พิษณุโลก", "เพชรบูรณ์", "เพชรบุรี", "แพร่", "ภูเก็ต", "มหาสารคาม", "มุกดาหาร", "แม่ฮ่องสอน",
    "ยโสธร", "ยะลา", "ร้อยเอ็ด", "ระนอง", "ระยอง", "ราชบุรี", "ลพบุรี", "ลำปาง", "ลำพูน", "เลย",
    "ศรีสะเกษ", "สกลนคร", "สงขลา", "สตูล", "สมุทรปราการ", "สมุทรสงคราม", "สมุทรสาคร", "สระแก้ว", "สระบุรี",
    "สิงห์บุรี", "สุโขทัย", "สุพรรณบุรี", "สุราษฎร์ธานี", "สุรินทร์", "หนองคาย", "หนองบัวลำภู", "อำนาจเจริญ",
    "อุดรธานี", "อุตรดิตถ์", "อุทัยธานี", "อุบลราชธานี", "อ่างทอง"
]
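# OCR output for the province line is usually noisy, so the helper below does a
# brute-force fuzzy match: it returns the province with the smallest Levenshtein
# (edit) distance to the recognized text, together with that distance.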
def get_closest_province(input_text, provinces):
    min_distance = float('inf')
    closest_province = None
    for province in provinces:
        distance = Levenshtein.distance(input_text, province)
        if distance < min_distance:
            min_distance = distance
            closest_province = province
    return closest_province, min_distance
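# Pipeline for a single uploaded image: CLAHE contrast enhancement -> YOLO
# detection of the plate-number (class 0) and province (class 1) regions ->
# adaptive-threshold binarization of each crop -> TrOCR text recognition.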
def process_image(image, processor, ocr_model, yolo_model):
    CONF_THRESHOLD = 0.2
    data = {"plate_number": "", "province": "", "raw_province": "", "plate_crop": None, "province_crop": None}
    # Convert PIL Image to cv2 format (force 3-channel RGB first, since PNG
    # uploads can carry an alpha channel that would break cv2.cvtColor)
    image = np.array(image.convert("RGB"))
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    # Image enhancement
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    cl = clahe.apply(l)
    enhanced = cv2.merge((cl, a, b))
    image = cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)
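    # (CLAHE on the LAB L-channel equalizes local contrast without shifting
    # colors, which tends to help on dim or unevenly lit plates.)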
    # YOLO detection
    results = yolo_model(image)
    # Process detections
    detections = []
    for result in results:
        for box in result.boxes:
            confidence = float(box.conf)
            class_id = int(box.cls.item())
            if confidence < CONF_THRESHOLD:
                continue
            x1, y1, x2, y2 = map(int, box.xyxy.flatten())
            detections.append((class_id, confidence, (x1, y1, x2, y2)))
    # Sort by class_id
    detections.sort(key=lambda x: x[0])
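    # Class 0 is the plate-number region and class 1 the province region (see
    # the branches below); sorting by class_id keeps the plate crop handled first.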
    for class_id, confidence, (x1, y1, x2, y2) in detections:
        cropped_image = image[y1:y2, x1:x2]
        if cropped_image.size == 0:
            continue
        # Preprocess for OCR
        cropped_image_gray = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2GRAY)
        thresh_image = cv2.adaptiveThreshold(
            cropped_image_gray,
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY_INV,
            11,
            2
        )
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
        thresh_image = cv2.morphologyEx(thresh_image, cv2.MORPH_CLOSE, kernel)
        cropped_image_3d = cv2.cvtColor(thresh_image, cv2.COLOR_GRAY2RGB)
        resized_image = cv2.resize(cropped_image_3d, (128, 32))
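        # (Adaptive thresholding plus a small morphological close yields a clean
        # binary glyph image; it is replicated to 3 channels because the TrOCR
        # processor expects RGB input.)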
        # OCR processing (no gradients are needed at inference time)
        pixel_values = processor(resized_image, return_tensors="pt").pixel_values
        with torch.no_grad():
            generated_ids = ocr_model.generate(pixel_values)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # Convert crop to PIL for display
        cropped_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
        if class_id == 0:  # License plate
            data["plate_number"] = generated_text
            data["plate_crop"] = cropped_pil
        elif class_id == 1:  # Province
            generated_province, distance = get_closest_province(generated_text, thai_provinces)
            data["raw_province"] = generated_text
            # Heuristic: accept the fuzzy match only when the edit distance is small
            # relative to the province name; otherwise leave "province" empty so the
            # UI can report that no close match was found.
            if generated_province is not None and distance <= len(generated_province) // 2:
                data["province"] = generated_province
            data["province_crop"] = cropped_pil
    return data
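# A minimal sketch of how the pieces above could be exercised outside the
# Streamlit UI (the filename 'sample_plate.jpg' is only an illustration):
#
#   processor, ocr_model, yolo_model = load_models()
#   if all(m is not None for m in (processor, ocr_model, yolo_model)):
#       out = process_image(Image.open('sample_plate.jpg'), processor, ocr_model, yolo_model)
#       print(out["plate_number"], out["province"])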
# Main app
st.title("Thai License Plate Detection 🚗")

# Load models
try:
    if not st.session_state['models_loaded']:
        with st.spinner("Loading models... (this may take a minute)"):
            processor, ocr_model, yolo_model = load_models()
            # Only cache the models when every one of them loaded successfully
            if processor is not None and ocr_model is not None and yolo_model is not None:
                st.session_state['models_loaded'] = True
                st.session_state['processor'] = processor
                st.session_state['ocr_model'] = ocr_model
                st.session_state['yolo_model'] = yolo_model
except Exception as e:
    st.error(f"Error loading models: {str(e)}")
    st.stop()

# Stop here if loading failed; load_models() has already reported the reason
if not st.session_state['models_loaded']:
    st.stop()
# File uploader
uploaded_file = st.file_uploader("Upload an image of a Thai license plate", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Display the uploaded image
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Uploaded Image")
            image = Image.open(uploaded_file)
            st.image(image, use_column_width=True)
        # Process the image
        with col2:
            st.subheader("Detection Results")
            with st.spinner("Processing image..."):
                results = process_image(
                    image,
                    st.session_state['processor'],
                    st.session_state['ocr_model'],
                    st.session_state['yolo_model']
                )
| if results["plate_number"]: | |
| st.success("Detection successful!") | |
| st.write("📝 License Plate:", results['plate_number']) | |
| if results['plate_crop'] is not None: | |
| st.subheader("Cropped License Plate") | |
| st.image(results['plate_crop'], caption="Detected License Plate Region") | |
| if results['raw_province']: | |
| st.write("🔍 Detected Province Text:", results['raw_province']) | |
| if results['province']: | |
| st.write("🏠 Matched Province:", results['province']) | |
| else: | |
| st.write("⚠️ No close province match found") | |
| if results['province_crop'] is not None: | |
| st.subheader("Cropped Province") | |
| st.image(results['province_crop'], caption="Detected Province Region") | |
| else: | |
| st.write("⚠️ No province text detected") | |
| else: | |
| st.error("No license plate detected in the image.") | |
| except Exception as e: | |
| st.error(f"An error occurred: {str(e)}") | |
| st.markdown("---") | |
| st.markdown("### Instructions") | |
| st.markdown(""" | |
| 1. Upload an image containing a Thai license plate | |
| 2. Wait for the processing to complete | |
| 3. View the detected license plate number and province | |
| """) | |
| # Add footer with GitHub link | |
| st.markdown("---") | |
| st.markdown("Made with ❤️ by [Your Name/Organization]") | |
| st.markdown("Check out the [GitHub Repository](https://github.com/yourusername/your-repo) for more information") |