# Spaces: Running on Zero
| import gradio as gr | |
| import json | |
| import base64 | |
| import tempfile | |
| import os | |
| from typing import Dict, List, Optional, Literal | |
| from datetime import datetime | |
| from PIL import Image, ImageDraw, ImageFont | |
| import io | |
| import spaces | |
| import shutil | |
| from pathlib import Path | |
| from htrflow.volume.volume import Collection | |
| from htrflow.pipeline.pipeline import Pipeline | |
# Model identifiers referenced by the pipeline presets below.
_LINE_SEGMENTER = "Riksarkivet/yolov9-lines-within-regions-1"
_REGION_SEGMENTER = "Riksarkivet/yolov9-regions-1"
_ENGLISH_RECOGNIZER = "microsoft/trocr-base-handwritten"
_SWEDISH_RECOGNIZER = "Riksarkivet/trocr-base-handwritten-hist-swe-2"


def _model_step(step, model, checkpoint, batch_size):
    """Return a new step dict for a model-backed pipeline step."""
    return {
        "step": step,
        "settings": {
            "model": model,
            "model_settings": {"model": checkpoint},
            "generation_settings": {"batch_size": batch_size},
        },
    }


def _letter_pipeline(recognizer):
    """Return the single-page preset: line segmentation, recognition, line ordering."""
    return {
        "steps": [
            _model_step("Segmentation", "yolo", _LINE_SEGMENTER, 8),
            _model_step("TextRecognition", "TrOCR", recognizer, 16),
            {"step": "OrderLines"},
        ]
    }


def _spread_pipeline(recognizer):
    """Return the two-page preset: region then line segmentation, recognition, reading order."""
    return {
        "steps": [
            _model_step("Segmentation", "yolo", _REGION_SEGMENTER, 4),
            _model_step("Segmentation", "yolo", _LINE_SEGMENTER, 8),
            _model_step("TextRecognition", "TrOCR", recognizer, 16),
            {"step": "ReadingOrderMarginalia", "settings": {"two_page": True}},
        ]
    }


# Pipeline configuration per supported document type. Every entry is built
# from its own factory call, so no nested dict is shared between presets.
PIPELINE_CONFIGS = {
    "letter_english": _letter_pipeline(_ENGLISH_RECOGNIZER),
    "letter_swedish": _letter_pipeline(_SWEDISH_RECOGNIZER),
    "spread_english": _spread_pipeline(_ENGLISH_RECOGNIZER),
    "spread_swedish": _spread_pipeline(_SWEDISH_RECOGNIZER),
}
def process_htr(image: Image.Image, document_type: Literal["letter_english", "letter_swedish", "spread_english", "spread_swedish"] = "letter_english", confidence_threshold: float = 0.8, custom_settings: Optional[str] = None) -> Dict:
    """Process handwritten text recognition on uploaded images using HTRflow pipelines.

    Args:
        image: The page image to transcribe. ``None`` yields an error result.
        document_type: Key into ``PIPELINE_CONFIGS`` selecting the preset
            pipeline. Ignored when ``custom_settings`` is given.
        confidence_threshold: Minimum confidence a recognised line needs in
            order to be included in ``results``.
        custom_settings: Optional JSON string holding a complete pipeline
            configuration. Blank or whitespace-only input counts as absent.

    Returns:
        On success a dict with ``success``, ``results``, ``processing_state``
        (a JSON string carrying the base64 image and run parameters) and
        ``metadata``. On failure a dict with ``success`` set to ``False``, an
        ``error`` message and ``results`` set to ``None``.
    """
    if image is None:
        return {"success": False, "error": "No image provided", "results": None}

    # Resolve the configuration first so bad input is rejected before any
    # temporary file is created.
    settings_text = custom_settings.strip() if custom_settings else ""
    if settings_text:
        try:
            config = json.loads(settings_text)
        except json.JSONDecodeError:
            return {"success": False, "error": "Invalid JSON in custom_settings parameter", "results": None}
    elif document_type in PIPELINE_CONFIGS:
        config = PIPELINE_CONFIGS[document_type]
    else:
        return {"success": False, "error": f"Unknown document_type: {document_type!r}", "results": None}

    # Reserve a path and close the handle straight away: the image is written
    # by name, and everything after this point runs under the cleanup below.
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    temp_image_path = temp_file.name
    temp_file.close()
    try:
        image.save(temp_image_path, "PNG")
        collection = Collection([temp_image_path])
        pipeline = Pipeline.from_config(config)
        try:
            processed_collection = pipeline.run(collection)
        except Exception as pipeline_error:
            return {"success": False, "error": f"Pipeline execution failed: {str(pipeline_error)}", "results": None}

        img_buffer = io.BytesIO()
        image.save(img_buffer, format="PNG")
        image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

        results = extract_text_results(processed_collection, confidence_threshold)
        processing_state = {
            "processed_collection": processed_collection,
            "image_base64": image_base64,
            "image_size": image.size,
            "document_type": document_type,
            # Recorded so that a consumer of the state can tell which
            # configuration produced these results.
            "custom_settings": settings_text or None,
            "confidence_threshold": confidence_threshold,
            "timestamp": datetime.now().isoformat(),
        }
        return {
            "success": True,
            "results": results,
            # default=str stringifies values json cannot encode natively
            # (the processed collection object).
            "processing_state": json.dumps(processing_state, default=str),
            "metadata": {
                "total_lines": len(results.get("text_lines", [])),
                "average_confidence": results.get("average_confidence", 0),
                "document_type": document_type,
                "image_dimensions": image.size,
            },
        }
    except Exception as e:
        return {"success": False, "error": f"HTR processing failed: {str(e)}", "results": None}
    finally:
        if os.path.exists(temp_image_path):
            os.unlink(temp_image_path)
def visualize_results(processing_state: str, visualization_type: Literal["overlay", "confidence_heatmap", "text_regions"] = "overlay", show_confidence: bool = True, highlight_low_confidence: bool = True, image: Optional[Image.Image] = None) -> Dict:
    """Generate interactive visualizations of HTR processing results.

    Args:
        processing_state: JSON string describing a previous run. Either the
            ``processing_state`` value itself or a JSON object that contains
            it under a ``processing_state`` key is accepted.
        visualization_type: Which rendering to produce.
        show_confidence: Whether confidence values are drawn in overlay mode.
        highlight_low_confidence: Whether low-confidence lines get a
            distinct colour in overlay mode.
        image: Optional image to draw on. When omitted, the base64 image
            stored in the state is used.

    Returns:
        On success a dict with ``success``, ``visualization`` (base64 PNG plus
        format, type and dimensions) and ``metadata``. On failure a dict with
        ``success`` set to ``False``, an ``error`` message and
        ``visualization`` set to ``None``.
    """
    temp_image_path = None
    try:
        state = json.loads(processing_state)
        # The UI asks users to paste the processing results, so tolerate the
        # full response object by unwrapping its nested state.
        if isinstance(state, dict) and "processing_state" in state:
            nested_state = state["processing_state"]
            state = json.loads(nested_state) if isinstance(nested_state, str) else nested_state

        if image is not None:
            original_image = image
        else:
            image_data = base64.b64decode(state["image_base64"])
            original_image = Image.open(io.BytesIO(image_data))

        # Use the custom pipeline recorded in the state when there is one,
        # otherwise fall back to the preset for the document type.
        recorded_settings = state.get("custom_settings")
        if recorded_settings:
            config = json.loads(recorded_settings)
        else:
            config = PIPELINE_CONFIGS[state["document_type"]]

        # Reserve a path and close the handle before writing by name; the
        # finally clause below removes the file on every exit path.
        temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        temp_image_path = temp_file.name
        temp_file.close()
        original_image.save(temp_image_path, "PNG")

        # The pipeline is run again on the image to obtain a collection to draw from.
        collection = Collection([temp_image_path])
        pipeline = Pipeline.from_config(config)
        processed_collection = pipeline.run(collection)

        viz_image = create_visualization(original_image, processed_collection, visualization_type, show_confidence, highlight_low_confidence)
        img_buffer = io.BytesIO()
        viz_image.save(img_buffer, format="PNG")
        img_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
        return {
            "success": True,
            "visualization": {
                "image_base64": img_base64,
                "image_format": "PNG",
                "visualization_type": visualization_type,
                "dimensions": viz_image.size,
            },
            "metadata": {"visualization_type": visualization_type},
        }
    except Exception as e:
        return {"success": False, "error": f"Visualization generation failed: {str(e)}", "visualization": None}
    finally:
        if temp_image_path is not None and os.path.exists(temp_image_path):
            os.unlink(temp_image_path)
def export_results(processing_state: str, output_formats: Optional[List[Literal["txt", "json", "alto", "page"]]] = None, confidence_filter: float = 0.0) -> Dict:
    """Export HTR results to multiple formats using HTRflow's native export functionality.

    Args:
        processing_state: JSON string describing a previous run. Either the
            ``processing_state`` value itself or a JSON object that contains
            it under a ``processing_state`` key is accepted.
        output_formats: Serializer names to export with. ``None`` means
            ``["txt"]``.
        confidence_filter: Echoed back in ``export_metadata``.

    Returns:
        On success a dict with ``success``, ``exports`` (format name mapped to
        a list of ``{"filename", "content"}`` entries, with
        ``"encoding": "base64"`` added for files that are not valid UTF-8)
        and ``export_metadata``. On failure a dict with ``success`` set to
        ``False``, an ``error`` message and ``exports`` set to ``None``.
    """
    # NOTE(review): confidence_filter is only reported in the metadata; it is
    # not applied to the exported content.
    temp_image_path = None
    try:
        # Copy so the list handed back in the metadata is never shared with
        # the caller's argument.
        formats = ["txt"] if output_formats is None else list(output_formats)

        state = json.loads(processing_state)
        # The UI asks users to paste the processing results, so tolerate the
        # full response object by unwrapping its nested state.
        if isinstance(state, dict) and "processing_state" in state:
            nested_state = state["processing_state"]
            state = json.loads(nested_state) if isinstance(nested_state, str) else nested_state

        image_data = base64.b64decode(state["image_base64"])
        image = Image.open(io.BytesIO(image_data))

        # Use the custom pipeline recorded in the state when there is one,
        # otherwise fall back to the preset for the document type.
        recorded_settings = state.get("custom_settings")
        if recorded_settings:
            config = json.loads(recorded_settings)
        else:
            config = PIPELINE_CONFIGS[state["document_type"]]

        # Reserve a path and close the handle before writing by name; the
        # finally clause below removes the file on every exit path.
        temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        temp_image_path = temp_file.name
        temp_file.close()
        image.save(temp_image_path, "PNG")

        collection = Collection([temp_image_path])
        pipeline = Pipeline.from_config(config)
        processed_collection = pipeline.run(collection)

        exports = {}
        # The context manager removes the export directory even when a
        # serializer or a file read fails part-way through.
        with tempfile.TemporaryDirectory() as temp_dir:
            for fmt in formats:
                export_dir = Path(temp_dir) / fmt
                processed_collection.save(directory=str(export_dir), serializer=fmt)
                export_files = []
                for root, _, filenames in os.walk(export_dir):
                    for filename in filenames:
                        file_path = os.path.join(root, filename)
                        try:
                            with open(file_path, "r", encoding="utf-8") as f:
                                content = f.read()
                            export_files.append({"filename": filename, "content": content})
                        except UnicodeDecodeError:
                            # Not text: ship the raw bytes as base64 instead.
                            with open(file_path, "rb") as f:
                                content = base64.b64encode(f.read()).decode("utf-8")
                            export_files.append({"filename": filename, "content": content, "encoding": "base64"})
                exports[fmt] = export_files

        return {
            "success": True,
            "exports": exports,
            "export_metadata": {
                "formats_generated": formats,
                "confidence_filter": confidence_filter,
                "timestamp": datetime.now().isoformat(),
            },
        }
    except Exception as e:
        return {"success": False, "error": f"Export generation failed: {str(e)}", "exports": None}
    finally:
        if temp_image_path is not None and os.path.exists(temp_image_path):
            os.unlink(temp_image_path)
def extract_text_results(collection: Collection, confidence_threshold: float) -> Dict:
    """Collect recognised text lines from a processed collection.

    Every node reached through ``page.traverse()`` on each page is kept when
    it has non-empty ``text`` and a ``confidence`` that is at least
    ``confidence_threshold``. Nodes whose confidence is missing or ``None``
    are skipped.

    Args:
        collection: Object exposing ``pages``, each with a ``traverse()`` method.
        confidence_threshold: Minimum confidence for a line to be kept.

    Returns:
        A dict with ``extracted_text`` (kept lines, each followed by a
        newline), ``text_lines`` (dicts with ``text``, ``confidence`` and
        ``bbox``), ``confidence_scores`` and ``average_confidence`` (``0``
        when nothing was kept).
    """
    text_lines = []
    confidence_scores = []
    for page in collection.pages:
        for node in page.traverse():
            text = getattr(node, "text", None)
            confidence = getattr(node, "confidence", None)
            # A None confidence cannot be compared with the threshold, so
            # such nodes are treated like nodes without a confidence.
            if not text or confidence is None or confidence < confidence_threshold:
                continue
            text_lines.append({
                "text": text,
                "confidence": confidence,
                "bbox": getattr(node, "bbox", None),
            })
            confidence_scores.append(confidence)

    return {
        "extracted_text": "".join(line["text"] + "\n" for line in text_lines),
        "text_lines": text_lines,
        "confidence_scores": confidence_scores,
        "average_confidence": sum(confidence_scores) / len(confidence_scores) if confidence_scores else 0,
    }
def create_visualization(image, collection, visualization_type, show_confidence, highlight_low_confidence):
    """Draw recognition results from ``collection`` onto a copy of ``image``.

    Args:
        image: PIL image to annotate; the input itself is not modified.
        collection: Object exposing ``pages``, each with a ``traverse()``
            method. Only nodes with a truthy ``bbox`` and ``text`` are drawn.
        visualization_type: ``"overlay"`` outlines each line (orange when
            ``highlight_low_confidence`` is set and confidence is below 0.7,
            green otherwise). ``"confidence_heatmap"`` fills each box with a
            translucent red (< 0.5), yellow (< 0.8) or green.
            ``"text_regions"`` outlines boxes in a repeating four-colour
            cycle. Any other value draws nothing.
        show_confidence: In overlay mode, print the confidence above each box.
        highlight_low_confidence: See ``visualization_type``.

    Returns:
        The annotated PIL image. Heatmaps are converted to RGB.
    """
    # RGB colour tuples are not valid inks on single-band modes such as "L",
    # so anything that is not already RGB(A) is converted while copying.
    viz_image = image.copy() if image.mode in ("RGB", "RGBA") else image.convert("RGB")
    draw = ImageDraw.Draw(viz_image)
    try:
        font = ImageFont.truetype("arial.ttf", 12)
    except (OSError, ImportError):
        # Font file not available (or no FreeType support): use the built-in font.
        font = ImageFont.load_default()

    region_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
    region_index = 0

    for page in collection.pages:
        for node in page.traverse():
            if not (hasattr(node, "bbox") and hasattr(node, "text") and node.bbox and node.text):
                continue
            bbox = node.bbox
            confidence = getattr(node, "confidence", None)
            if confidence is None:
                # Missing and None confidences both count as fully confident.
                confidence = 1.0

            if visualization_type == "overlay":
                color = (255, 165, 0) if highlight_low_confidence and confidence < 0.7 else (0, 255, 0)
                draw.rectangle(bbox, outline=color, width=2)
                if show_confidence:
                    # Keep the label on the canvas for boxes touching the top edge.
                    label_y = max(bbox[1] - 15, 0)
                    draw.text((bbox[0], label_y), f"{confidence:.2f}", fill=color, font=font)
            elif visualization_type == "confidence_heatmap":
                if confidence < 0.5:
                    color = (255, 0, 0, 100)
                elif confidence < 0.8:
                    color = (255, 255, 0, 100)
                else:
                    color = (0, 255, 0, 100)
                overlay = Image.new("RGBA", viz_image.size, (0, 0, 0, 0))
                overlay_draw = ImageDraw.Draw(overlay)
                overlay_draw.rectangle(bbox, fill=color)
                viz_image = Image.alpha_composite(viz_image.convert("RGBA"), overlay)
            elif visualization_type == "text_regions":
                # Cycle by position, not hash(str(bbox)): str hashes are
                # salted per process, which made the colours vary between runs.
                color = region_colors[region_index % len(region_colors)]
                region_index += 1
                draw.rectangle(bbox, outline=color, width=3)

    return viz_image.convert("RGB") if visualization_type == "confidence_heatmap" else viz_image
def create_htrflow_mcp_server():
    """Assemble the three tool interfaces into one tabbed Gradio app and return it."""
    document_types = ["letter_english", "letter_swedish", "spread_english", "spread_swedish"]
    visualization_types = ["overlay", "confidence_heatmap", "text_regions"]
    export_formats = ["txt", "json", "alto", "page"]

    # Tab 1: run a pipeline on an uploaded image.
    processing_tab = gr.Interface(
        fn=process_htr,
        inputs=[
            gr.Image(type="pil", label="Upload Image"),
            gr.Dropdown(choices=document_types, value="letter_english", label="Document Type"),
            gr.Slider(0.0, 1.0, value=0.8, label="Confidence Threshold"),
            gr.Textbox(label="Custom Settings (JSON)", placeholder="Optional custom pipeline settings"),
        ],
        outputs=gr.JSON(label="Processing Results"),
        title="HTR Processing Tool",
        description="Process handwritten text using configurable HTRflow pipelines",
        api_name="process_htr",
    )

    # Tab 2: render the results of a previous run.
    visualization_tab = gr.Interface(
        fn=visualize_results,
        inputs=[
            gr.Textbox(label="Processing State (JSON)", placeholder="Paste processing results from HTR tool"),
            gr.Dropdown(choices=visualization_types, value="overlay", label="Visualization Type"),
            gr.Checkbox(value=True, label="Show Confidence Scores"),
            gr.Checkbox(value=True, label="Highlight Low Confidence"),
            gr.Image(type="pil", label="Image (optional)"),
        ],
        outputs=gr.JSON(label="Visualization Results"),
        title="Results Visualization Tool",
        description="Generate interactive visualizations of HTR results",
        api_name="visualize_results",
    )

    # Tab 3: serialise the results of a previous run.
    export_tab = gr.Interface(
        fn=export_results,
        inputs=[
            gr.Textbox(label="Processing State (JSON)", placeholder="Paste processing results from HTR tool"),
            gr.CheckboxGroup(choices=export_formats, value=["txt"], label="Output Formats"),
            gr.Slider(0.0, 1.0, value=0.0, label="Confidence Filter"),
        ],
        outputs=gr.JSON(label="Export Results"),
        title="Export Tool",
        description="Export HTR results to multiple formats",
        api_name="export_results",
    )

    tabs = [processing_tab, visualization_tab, export_tab]
    tab_names = ["HTR Processing", "Results Visualization", "Export Results"]
    return gr.TabbedInterface(tabs, tab_names, title="HTRflow MCP Server")
# Script entry point: build the tabbed app and serve it when run directly.
if __name__ == "__main__":
    demo = create_htrflow_mcp_server()
    # Launch with Gradio's mcp_server option switched on.
    demo.launch(mcp_server=True)