Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import json | |
| import tempfile | |
| import os | |
| import zipfile | |
| import shutil | |
| from typing import List, Optional, Literal, Union, Dict | |
| from PIL import Image | |
| import requests | |
| from pathlib import Path | |
| import spaces | |
| from visualizer import htrflow_visualizer | |
| from htrflow.volume.volume import Collection | |
| from htrflow.pipeline.pipeline import Pipeline | |
# Static type aliases for the two user-selectable options. They mirror the
# runtime lists below value-for-value.
FileChoices = Literal["txt", "alto", "page", "json"]
FormatChoices = Literal[
    "letter_english",
    "letter_swedish",
    "spread_english",
    "spread_swedish",
]

# Runtime choice lists, used to populate the dropdowns in the UI.
FILE_CHOICES = ["txt", "alto", "page", "json"]
FORMAT_CHOICES = ["letter_english", "letter_swedish", "spread_english", "spread_swedish"]

# Export format used when the caller does not pick one.
DEFAULT_OUTPUT = "alto"
def _yolo_segmentation(model_name, batch_size):
    """Build one YOLO "Segmentation" step for the given model and batch size."""
    return {
        "step": "Segmentation",
        "settings": {
            "model": "yolo",
            "model_settings": {"model": model_name},
            "generation_settings": {"batch_size": batch_size},
        },
    }


def _trocr_recognition(model_name):
    """Build one TrOCR "TextRecognition" step (batch size 16) for the given model."""
    return {
        "step": "TextRecognition",
        "settings": {
            "model": "TrOCR",
            "model_settings": {"model": model_name},
            "generation_settings": {"batch_size": 16},
        },
    }


def _letter_pipeline(recognition_model):
    """Letter layout: line segmentation, recognition, then line ordering."""
    return {
        "steps": [
            _yolo_segmentation("Riksarkivet/yolov9-lines-within-regions-1", 8),
            _trocr_recognition(recognition_model),
            {"step": "OrderLines"},
        ]
    }


def _spread_pipeline(recognition_model):
    """Spread layout: region pass, line pass, recognition, two-page reading order."""
    return {
        "steps": [
            _yolo_segmentation("Riksarkivet/yolov9-regions-1", 4),
            _yolo_segmentation("Riksarkivet/yolov9-lines-within-regions-1", 8),
            _trocr_recognition(recognition_model),
            {"step": "ReadingOrderMarginalia", "settings": {"two_page": True}},
        ]
    }


# Pipeline configuration per document type. The English and Swedish variants
# of a layout differ only in the text-recognition model. Every entry is built
# from fresh dicts, so no nested object is shared between configurations.
PIPELINE_CONFIGS = {
    "letter_english": _letter_pipeline("microsoft/trocr-base-handwritten"),
    "letter_swedish": _letter_pipeline("Riksarkivet/trocr-base-handwritten-hist-swe-2"),
    "spread_english": _spread_pipeline("microsoft/trocr-base-handwritten"),
    "spread_swedish": _spread_pipeline("Riksarkivet/trocr-base-handwritten-hist-swe-2"),
}
def handle_image_input(
    image_path: Union[str, None],
    progress: Optional[gr.Progress] = None,
    desc_prefix: str = "",
) -> str:
    """
    Resolve an image reference (local path or http(s) URL) to a local file path.

    A URL is downloaded into a temporary file, which is then checked with
    Pillow. A local path is returned unchanged once it is confirmed to exist.

    Args:
        image_path: Path to image file or URL
        progress: Progress tracker for UI updates (optional)
        desc_prefix: Prefix for progress descriptions
    Returns:
        Local file path to the image. For URL input this is a newly created
        temporary file that the caller is responsible for deleting.
    Raises:
        ValueError: If no input is given, the download fails, the downloaded
            payload is not a valid image, or the local file does not exist.
    """
    if not image_path:
        raise ValueError("No image provided. Please upload an image or provide a URL.")

    # Test for presence explicitly rather than relying on the truthiness of a
    # gr.Progress object.
    if progress is not None:
        progress(0.1, desc=f"{desc_prefix}Processing image input...")

    # If it's a URL, download the image
    if isinstance(image_path, str) and image_path.startswith(("http://", "https://")):
        if progress is not None:
            progress(0.2, desc=f"{desc_prefix}Downloading image from URL...")

        # NOTE(review): the URL is fetched exactly as supplied, with no host
        # allow-list and no response size cap; confirm this is acceptable for
        # the deployment before exposing it to untrusted callers.
        try:
            response = requests.get(image_path, timeout=30)
            response.raise_for_status()
        except requests.RequestException as e:
            raise ValueError(f"Failed to download image from URL: {str(e)}") from e

        # Save to temporary file. The suffix is always ".jpg", whatever the
        # payload's real format is.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
            tmp_file.write(response.content)
            image_path = tmp_file.name

        # Verify it's a valid image. The context manager closes the file
        # handle so the temp file can be removed on every platform.
        try:
            with Image.open(image_path) as img:
                img.verify()
        except Exception as e:
            os.unlink(image_path)
            raise ValueError(f"Downloaded file is not a valid image: {str(e)}") from e

    # Verify the file exists
    if not os.path.exists(image_path):
        raise ValueError(f"Image file not found: {image_path}")
    return image_path
def parse_image_input(image_input: Union[str, List[str], None]) -> List[str]:
    """
    Normalise image input into a list of paths/URLs.

    Args:
        image_input: Single image path, list of paths, or newline-separated URLs/paths
    Returns:
        List of image paths/URLs. A list is returned as given. A string is
        split into lines, each entry is stripped, and blank lines are dropped.
        Empty, whitespace-only or unsupported input yields an empty list.
    """
    if not image_input:
        return []
    if isinstance(image_input, list):
        return image_input
    if isinstance(image_input, str):
        # splitlines() copes with \n, \r\n and \r alike. A string holding only
        # whitespace produces no entries, so callers see "no images" rather
        # than a bogus blank path.
        stripped = (line.strip() for line in image_input.splitlines())
        return [entry for entry in stripped if entry]
    return []
def _process_htr_pipeline_batch(
    image_paths: List[str],
    document_type: FormatChoices,
    custom_settings: Optional[str] = None,
    progress: Optional[gr.Progress] = None,
) -> Dict[str, Union[Collection, str]]:
    """
    Run the HTR pipeline over several images, one at a time.

    Args:
        image_paths: Local paths and/or http(s) URLs of the images to process
        document_type: Key into PIPELINE_CONFIGS, used when no custom settings are given
        custom_settings: Optional JSON pipeline configuration that replaces the preset
        progress: Progress tracker for UI updates (optional)
    Returns:
        Mapping of image name to the processed Collection, or to an
        "Error: ..." string when that image failed. A failure on one image
        does not abort the rest of the batch.
    Raises:
        ValueError: If custom_settings is not valid JSON, or document_type is
            not a known preset.
    """
    if custom_settings:
        try:
            config = json.loads(custom_settings)
        except json.JSONDecodeError as e:
            raise ValueError("Invalid JSON in custom_settings parameter. Please check your JSON syntax.") from e
    elif document_type in PIPELINE_CONFIGS:
        config = PIPELINE_CONFIGS[document_type]
    else:
        # Report a bad preset as an input error instead of a bare KeyError.
        raise ValueError(
            f"Unknown document_type {document_type!r}. "
            f"Expected one of: {', '.join(PIPELINE_CONFIGS)}"
        )

    # Initialize pipeline once for all images
    pipeline = Pipeline.from_config(config)

    results: Dict[str, Union[Collection, str]] = {}
    temp_files = []
    total_images = len(image_paths)

    try:
        for idx, image_path in enumerate(image_paths):
            position = idx + 1

            # Resolve the display name before entering the try block so the
            # error handler always refers to the current image.
            is_url = isinstance(image_path, str) and image_path.startswith(("http://", "https://"))
            image_name = f"image_{position}" if is_url else Path(str(image_path)).stem
            if image_name in results:
                # Two inputs sharing a stem must not overwrite each other.
                image_name = f"{image_name}_{position}"

            try:
                if progress is not None:
                    progress((idx + 0.2) / total_images,
                             desc=f"Processing image {position}/{total_images}: {image_name}")

                # Handle image input
                processed_path = handle_image_input(
                    image_path, progress, desc_prefix=f"[{position}/{total_images}] "
                )

                # Only a download yields a path different from the input.
                # Files the caller passed in are never deleted, even when
                # they live under the system temp directory.
                if processed_path != image_path:
                    temp_files.append(processed_path)

                if progress is not None:
                    progress((idx + 0.5) / total_images,
                             desc=f"Running HTR on image {position}/{total_images}: {image_name}")

                # Process with pipeline
                collection = Collection([processed_path])
                results[image_name] = pipeline.run(collection)

                if progress is not None:
                    progress((idx + 1.0) / total_images,
                             desc=f"Completed image {position}/{total_images}: {image_name}")
            except Exception as e:
                # Per-image boundary: record the failure and keep going.
                results[image_name] = f"Error: {str(e)}"
                print(f"Error processing {image_path}: {str(e)}")
    finally:
        # Cleanup temp files (best effort; a missing file is not an error)
        for temp_file in temp_files:
            try:
                os.unlink(temp_file)
            except OSError:
                pass

    if progress is not None:
        progress(1.0, desc=f"Completed processing all {total_images} images!")
    return results
def extract_text_from_collection(collection: Collection) -> str:
    """Return the text of every text-bearing node in the collection, one per line.

    Pages are visited in order and each page's nodes in the order yielded by
    ``page.traverse()``. Nodes without a ``text`` attribute, or with empty
    text, are skipped.
    """
    return "\n".join(
        node.text
        for page in collection.pages
        for node in page.traverse()
        if hasattr(node, "text") and node.text
    )
def htr_text(
    image_input: Union[str, List[str]],
    document_type: FormatChoices = "letter_swedish",
    custom_settings: Optional[str] = None,
    return_format: str = "separate",  # "separate" or "combined"
    progress: gr.Progress = gr.Progress()
) -> str:
    """
    Extract text from handwritten documents using HTR.
    Handles both single images and multiple images.

    Args:
        image_input: Single image path/URL, multiple paths/URLs (newline-separated), or list of uploaded files
        document_type: Type of document layout - choose based on your documents' structure and language
        custom_settings: Optional JSON configuration for advanced pipeline customization
        return_format: "separate" to show each document's text separately, "combined" to merge all text
        progress: Progress tracker for UI updates
    Returns:
        Extracted text from all handwritten documents
    """
    # NOTE(review): with launch(mcp_server=True) Gradio is documented to build
    # the tool description from this docstring; keep the Args/Returns layout.
    try:
        # Test for presence explicitly rather than relying on the truthiness
        # of a gr.Progress object.
        if progress is not None:
            progress(0, desc="Starting HTR text extraction...")

        # Parse input to get list of images
        image_paths = parse_image_input(image_input)
        if not image_paths:
            return "No images provided. Please upload images or provide URLs."

        # Adjust description based on single vs multiple
        num_images = len(image_paths)
        if progress is not None:
            progress(0.1, desc=f"Processing {num_images} image{'s' if num_images > 1 else ''}...")

        # Process all images
        results = _process_htr_pipeline_batch(
            image_paths, document_type, custom_settings, progress
        )

        # Any value other than "separate" is treated as "combined".
        separate = return_format == "separate"

        all_texts = []
        for image_name, collection in results.items():
            if isinstance(collection, str):
                # Error case: always shown under its own header so the failed
                # document can be identified, even in combined mode.
                all_texts.append(f"=== {image_name} ===\n{collection}\n")
                continue
            text = extract_text_from_collection(collection)
            all_texts.append(f"=== {image_name} ===\n{text}\n" if separate else text)

        # Headed blocks already end with a newline; combined text is separated
        # by a blank line.
        return ("\n" if separate else "\n\n").join(all_texts)
    except ValueError as e:
        return f"Input error: {str(e)}"
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"HTR text extraction failed: {str(e)}"
def htr_generate_files(
    image_input: Union[str, List[str]],
    document_type: FormatChoices = "letter_swedish",
    output_format: FileChoices = DEFAULT_OUTPUT,
    custom_settings: Optional[str] = None,
    progress: gr.Progress = gr.Progress()
) -> str:
    """
    Process handwritten documents and generate formatted output files.
    Returns a ZIP file for multiple documents, or single file for single document.

    Args:
        image_input: Single image path/URL, multiple paths/URLs (newline-separated), or list of uploaded files
        document_type: Type of document layout - affects segmentation and reading order
        output_format: Desired output format (txt for plain text, alto/page for XML with coordinates, json for structured data)
        custom_settings: Optional JSON configuration for advanced pipeline customization
        progress: Progress tracker for UI updates
    Returns:
        Path to generated file(s)
    """
    # NOTE(review): with launch(mcp_server=True) Gradio is documented to build
    # the tool description from this docstring; keep the Args/Returns layout.

    def update_progress(fraction, desc):
        # Explicit None test; does not rely on gr.Progress truthiness.
        if progress is not None:
            progress(fraction, desc=desc)

    def write_error_file(message):
        # Failures are returned as a downloadable text file because the
        # interface output is a file component.
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".txt", encoding="utf-8"
        ) as error_file:
            error_file.write(message)
        return error_file.name

    try:
        update_progress(0, "Starting HTR file processing...")

        # Parse input to get list of images
        image_paths = parse_image_input(image_input)
        if not image_paths:
            return write_error_file("No images provided. Please upload images or provide URLs.")

        num_images = len(image_paths)
        update_progress(0.1, f"Processing {num_images} image{'s' if num_images > 1 else ''}...")

        # Process all images
        results = _process_htr_pipeline_batch(
            image_paths, document_type, custom_settings, progress
        )
        update_progress(0.9, "Creating output files...")

        # Output directory. It is deliberately not removed here: the returned
        # path must still exist after this function returns.
        temp_dir = Path(tempfile.mkdtemp())
        output_files = []

        for image_name, collection in results.items():
            if isinstance(collection, str):
                # Error case: write the message to a text file
                error_file_path = temp_dir / f"{image_name}_error.txt"
                error_file_path.write_text(collection, encoding="utf-8")
                output_files.append(error_file_path)
                continue

            # Save collection in requested format
            export_dir = temp_dir / image_name
            collection.save(directory=str(export_dir), serializer=output_format)

            # Take the first exported file found under the export directory.
            exported = next(
                (
                    Path(root) / file_name
                    for root, _, file_names in os.walk(export_dir)
                    for file_name in file_names
                ),
                None,
            )
            if exported is None:
                # Record the miss so the document is not silently absent
                # from the result.
                missing_path = temp_dir / f"{image_name}_error.txt"
                missing_path.write_text(
                    f"Error: no {output_format} output file was generated",
                    encoding="utf-8",
                )
                output_files.append(missing_path)
                continue

            # Rename to <image_name><ext>, falling back to the requested
            # format when the exported file has no extension.
            suffix = exported.suffix or f".{output_format}"
            new_path = temp_dir / f"{image_name}{suffix}"
            shutil.move(str(exported), str(new_path))
            shutil.rmtree(export_dir, ignore_errors=True)
            output_files.append(new_path)

        # Return single file or ZIP based on input count
        if len(output_files) == 1 and len(image_paths) == 1:
            update_progress(1.0, "Processing complete!")
            return str(output_files[0])

        zip_path = temp_dir / f"htr_output_{output_format}.zip"
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for file_path in output_files:
                zipf.write(file_path, file_path.name)
        update_progress(1.0, f"Processing complete! Generated {len(output_files)} files.")
        return str(zip_path)
    except Exception as e:
        # UI boundary: ValueError marks bad input, anything else a processing
        # failure.
        prefix = "Input error" if isinstance(e, ValueError) else "HTR file generation failed"
        return write_error_file(f"{prefix}: {str(e)}")
def htr_visualize(
    image_input: Union[str, List[str]],
    htr_documents: Union[List[str], None],
    progress: gr.Progress = gr.Progress()
) -> str:
    """
    Create visualizations for HTR results overlaid on original documents.
    Returns a ZIP file for multiple documents, or single image for single document.

    Args:
        image_input: Original document image paths/URLs (newline-separated if string)
        htr_documents: HTR output files (ALTO/PAGE XML) - must match order of images
        progress: Progress tracker for UI updates
    Returns:
        Path to visualization file(s)
    """
    # NOTE(review): with launch(mcp_server=True) Gradio is documented to build
    # the tool description from this docstring; keep the Args/Returns layout.

    def update_progress(fraction, desc):
        # Explicit None test; does not rely on gr.Progress truthiness.
        if progress is not None:
            progress(fraction, desc=desc)

    def as_path(document):
        # Accept plain paths as well as file-like objects exposing `.name`.
        if isinstance(document, str):
            return document
        if hasattr(document, 'name'):
            return document.name
        return str(document)

    try:
        update_progress(0, "Starting visualization generation...")

        # Parse inputs
        image_paths = parse_image_input(image_input)
        if not htr_documents:
            raise ValueError("No HTR documents provided")

        # A single document is treated as a one-element list.
        documents = htr_documents if isinstance(htr_documents, list) else [htr_documents]
        htr_paths = [as_path(document) for document in documents]

        if not image_paths:
            raise ValueError("No images provided")
        if len(image_paths) != len(htr_paths):
            raise ValueError(f"Number of images ({len(image_paths)}) doesn't match number of HTR documents ({len(htr_paths)})")

        num_docs = len(image_paths)
        update_progress(0.1, f"Creating visualization{'s' if num_docs > 1 else ''} for {num_docs} document{'s' if num_docs > 1 else ''}...")

        # Output directory. It is not removed here because the returned path
        # must still exist after this function returns.
        temp_dir = Path(tempfile.mkdtemp())
        output_files = []
        temp_files = []
        used_names = set()

        try:
            for idx, (image_path, htr_path) in enumerate(zip(image_paths, htr_paths)):
                position = idx + 1

                # Resolve the name before the try block so the error handler
                # always refers to the current document.
                is_url = isinstance(image_path, str) and image_path.startswith(("http://", "https://"))
                image_name = f"image_{position}" if is_url else Path(str(image_path)).stem
                if image_name in used_names:
                    # Inputs sharing a stem must not overwrite each other's output.
                    image_name = f"{image_name}_{position}"
                used_names.add(image_name)

                try:
                    update_progress((idx + 0.3) / num_docs,
                                    f"Visualizing document {position}/{num_docs}: {image_name}")

                    # Handle image input
                    processed_image = handle_image_input(
                        image_path, progress, desc_prefix=f"[{position}/{num_docs}] "
                    )
                    # Only a download yields a path different from the input.
                    # Caller-supplied files are never deleted.
                    if processed_image != image_path:
                        temp_files.append(processed_image)

                    # Generate visualization - the last argument is the output path
                    output_viz_path = temp_dir / f"{image_name}_visualization.png"
                    viz_result = htrflow_visualizer(processed_image, htr_path, str(output_viz_path))

                    if output_viz_path.exists():
                        output_files.append(output_viz_path)
                    elif viz_result and os.path.exists(viz_result):
                        # Fallback: the visualizer wrote to a different location
                        shutil.move(viz_result, str(output_viz_path))
                        output_files.append(output_viz_path)
                    else:
                        raise ValueError("Visualization generation failed - no output file created")
                except Exception as e:
                    # Per-document boundary: record the failure and keep going.
                    error_path = temp_dir / f"{image_name}_viz_error.txt"
                    error_path.write_text(f"Visualization failed: {str(e)}", encoding="utf-8")
                    output_files.append(error_path)
                    print(f"Error visualizing {image_name}: {str(e)}")
        finally:
            # Cleanup temp files (best effort; a missing file is not an error)
            for temp_file in temp_files:
                try:
                    os.unlink(temp_file)
                except OSError:
                    pass

        # Return single file or ZIP based on input count
        if len(output_files) == 1 and num_docs == 1:
            update_progress(1.0, "Visualization complete!")
            return str(output_files[0])

        update_progress(0.9, "Creating ZIP archive...")
        zip_path = temp_dir / "htr_visualizations.zip"
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for file_path in output_files:
                zipf.write(file_path, file_path.name)
        update_progress(1.0, f"Visualization complete! Created {len(output_files)} visualization{'s' if len(output_files) > 1 else ''}.")
        return str(zip_path)
    except Exception as e:
        # UI boundary: return the failure as a downloadable text file.
        with tempfile.NamedTemporaryFile(
            mode='w', delete=False, suffix='.txt', encoding="utf-8"
        ) as error_file:
            error_file.write(f"Visualization failed: {str(e)}")
        return error_file.name
def create_htrflow_mcp_server():
    """Assemble the three-tab Gradio app: text extraction, file export, visualization.

    Returns:
        The ``gr.TabbedInterface`` wrapping the three interfaces; the caller
        is responsible for launching it.
    """

    # Factories for the widgets that appear on more than one tab. Each call
    # creates a fresh component, so no instance is shared between interfaces.
    def image_input_box():
        return gr.Textbox(
            label="Image Input",
            placeholder="Single image path/URL or multiple (one per line)\nYou can also drag and drop files here",
            lines=3,
        )

    def document_type_dropdown():
        return gr.Dropdown(
            choices=FORMAT_CHOICES,
            value="letter_swedish",
            label="Document Type",
            info="Select the type that best matches your documents' layout and language",
        )

    def custom_settings_box():
        return gr.Textbox(
            label="Custom Settings (JSON)",
            placeholder='{"steps": [...]} - Leave empty for default settings',
            value="",
            lines=3,
        )

    # Tab 1: plain text extraction
    text_tab = gr.Interface(
        fn=htr_text,
        inputs=[
            image_input_box(),
            document_type_dropdown(),
            custom_settings_box(),
            gr.Radio(
                choices=["separate", "combined"],
                value="separate",
                label="Output Format",
                info="'separate' shows each document's text with headers, 'combined' merges all text",
            ),
        ],
        outputs=[gr.Textbox(label="Extracted Text", lines=20)],
        title="Extract Text from Handwritten Documents",
        description="Process one or more handwritten document images. Works with letters and book spreads in English and Swedish.",
        api_name="htr_text",
    )

    # Tab 2: export to txt / ALTO / PAGE / JSON files
    files_tab = gr.Interface(
        fn=htr_generate_files,
        inputs=[
            image_input_box(),
            document_type_dropdown(),
            gr.Dropdown(
                choices=FILE_CHOICES,
                value=DEFAULT_OUTPUT,
                label="Output Format",
                info="ALTO/PAGE: XML with coordinates | JSON: Structured data | TXT: Plain text only",
            ),
            custom_settings_box(),
        ],
        outputs=[gr.File(label="Download HTR Output")],
        title="Generate HTR Output Files",
        description="Process handwritten documents and export in various formats. Returns ZIP for multiple files.",
        api_name="htr_generate_files",
    )

    # Tab 3: overlay existing HTR XML output on the source images
    viz_tab = gr.Interface(
        fn=htr_visualize,
        inputs=[
            gr.Textbox(
                label="Original Image Paths/URLs",
                placeholder="One path/URL per line",
                lines=3,
            ),
            gr.File(
                label="Upload HTR XML Files (ALTO/PAGE)",
                file_types=[".xml"],
                file_count="multiple",
            ),
        ],
        outputs=gr.File(label="Download Visualization"),
        title="Visualize HTR Results",
        description="Create annotated images showing detected regions and text. Files must be in matching order.",
        api_name="htr_visualize",
    )

    tabs = [text_tab, files_tab, viz_tab]
    tab_titles = [
        "📚 Extract Text",
        "📁 Generate Files",
        "🖼️ Visualize Results",
    ]
    return gr.TabbedInterface(
        tabs,
        tab_titles,
        title="🖋️ HTRflow - Handwritten Text Recognition",
        analytics_enabled=False,
    )
if __name__ == "__main__":
    # Script entry point: build the app and serve it with the MCP server
    # enabled, no public share link, and the API page shown.
    launch_options = {
        "mcp_server": True,
        "share": False,
        "debug": False,
        "show_api": True,
        "favicon_path": None,
    }
    demo = create_htrflow_mcp_server()
    demo.launch(**launch_options)