Spaces:
Sleeping
Sleeping
| from datasets import load_dataset | |
| from functools import partial | |
| from pandas import DataFrame | |
| import earthview as ev | |
| import utils | |
| import gradio as gr | |
| import tqdm | |
| import os | |
| import numpy as np | |
# DEBUG selects the data source used by the handlers below:
#   False      -> normal operation
#   "random"   -> random data
#   "samples"  -> local parquet samples
DEBUG = False

# Module-level state shared between the event handlers.
app_state = dict(
    dsi=None,  # Dataset iterator
    subset=None,  # Currently loaded subset
)
def open_dataset(dataset, subset, split, batch_size, shard_value, only_rgb):
    """Load a dataset subset (optionally a single shard) and fetch the first batch.

    The freshly created iterator and the subset name are stored in the
    module-level ``app_state`` so that ``get_images`` can continue from it.

    Args:
        dataset (str): Name of the main dataset.
        subset (str): Name of the subset to load.
        split (str): Data split (e.g., 'train', 'test').
        batch_size (int): Number of items to fetch per batch.
        shard_value (int): The specific shard index to load (-1 for all).
        only_rgb (bool): Whether to load only RGB images.

    Returns:
        tuple: ``(updated_shard_slider, initial_gallery_images,
        initial_metadata_table)`` for the Gradio interface.

    Raises:
        gr.Error: If the numeric inputs are not whole numbers, or if the shard
            count or the dataset itself cannot be loaded.
        ValueError: If ``DEBUG`` holds an unsupported value.
    """
    print(f"Loading dataset: {dataset}, subset: {subset}, split: {split}, shard: {shard_value}")

    # UI number/slider widgets may hand over floats (or None when the field is
    # cleared); range() and shard indexing below need real ints.
    try:
        batch_size = int(batch_size)
        shard_value = int(shard_value)
    except (TypeError, ValueError, OverflowError) as e:
        raise gr.Error("Batch size and shard must be whole numbers.") from e

    try:
        nshards = ev.get_nshards(subset)  # Total number of shards for the subset
    except Exception as e:
        raise gr.Error(f"Failed to get shard count for subset '{subset}': {e}") from e

    # Highest valid shard index; never negative so the slider stays well-formed
    # even when the subset reports zero shards.
    last_shard = max(nshards - 1, 0)

    # Determine which shards to load
    if shard_value == -1:
        shards_to_load = None  # Load all shards
        print("Loading all shards.")
    else:
        # Clamp the selected shard into the valid range
        shard_value = max(0, min(shard_value, last_shard))
        shards_to_load = [shard_value]
        print(f"Loading shard {shard_value} out of {nshards}.")

    # Load the dataset based on DEBUG configuration
    if DEBUG == "random":
        print("DEBUG MODE: Using random data.")
        ds = range(batch_size * 2)  # Enough for a couple of batches
    elif DEBUG == "samples":
        print("DEBUG MODE: Using local Parquet samples.")
        try:
            ds = ev.load_parquet(subset, batch_size=batch_size * 2)
        except Exception as e:
            raise gr.Error(f"Failed to load Parquet samples for '{subset}': {e}") from e
    elif not DEBUG:
        print("Loading dataset from source...")
        try:
            ds = ev.load_dataset(subset, dataset=dataset, split=split, shards=shards_to_load, cache_dir="dataset")
        except Exception as e:
            raise gr.Error(f"Failed to load dataset '{dataset}/{subset}': {e}") from e
    else:
        raise ValueError("Invalid DEBUG setting.")

    # Create an iterator and store it in the state before fetching, because
    # get_images reads both entries from app_state.
    app_state["dsi"] = iter(ds)
    app_state["subset"] = subset
    print("Dataset loaded, fetching initial batch...")
    images, metadata_df = get_images(batch_size, only_rgb)

    slider_kwargs = {"label": f"Shard (0 to {last_shard})", "maximum": last_shard}
    if shards_to_load is not None:
        # Only move the slider when a concrete shard was loaded; -1 ("all
        # shards") lies below the slider minimum and is not a valid position.
        slider_kwargs["value"] = shard_value
    updated_shard_slider = gr.Slider(**slider_kwargs)
    return updated_shard_slider, images, metadata_df
# Image keys shown per subset: (always shown, shown only when only_rgb is off).
_SUBSET_IMAGE_KEYS = {
    "satellogic": (("rgb",), ("1m",)),
    "sentinel_1": (("10m",), ()),
    "sentinel_2": (("rgb",), ("10m", "20m", "scl")),
    "neon": (("rgb",), ("chm", "1m")),
}


def _safe_item_id(item):
    """Return the item's "id" entry, or "Error" when it cannot be read."""
    try:
        return item.get("id", "Error")
    except Exception:
        # Items without a mapping-style .get() must not break error reporting.
        return "Error"


def get_images(batch_size, only_rgb):
    """Fetch the next batch of images and metadata from the current iterator.

    Reads the iterator and subset name that ``open_dataset`` stored in
    ``app_state``. Items that fail to process are reported as an error row in
    the metadata instead of aborting the batch.

    Args:
        batch_size (int): Number of items to fetch.
        only_rgb (bool): Whether to load only RGB images.

    Returns:
        tuple: (list_of_images, pandas_dataframe_of_metadata)

    Raises:
        gr.Error: If no dataset has been loaded yet, or if ``batch_size`` is
            not a whole number.
    """
    dsi = app_state.get("dsi")
    subset = app_state.get("subset")
    if dsi is None or subset is None:
        raise gr.Error("You need to load a Dataset first using the 'Load' button.")

    # The UI number widget may deliver a float or None; trange needs an int.
    try:
        batch_size = int(batch_size)
    except (TypeError, ValueError, OverflowError) as e:
        raise gr.Error("Batch size must be a whole number.") from e

    images = []
    metadatas = []
    print(f"Fetching next {batch_size} images...")
    for i in tqdm.trange(batch_size, desc=f"Getting images for {subset}"):
        if DEBUG == "random":
            # Generate random image and basic metadata for debugging
            img_rgb = np.random.randint(0, 255, (384, 384, 3), dtype=np.uint8)
            images.append(img_rgb)
            if not only_rgb:
                img_other = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
                images.append(img_other)
            metadatas.append({"id": f"random_{i}", "bounds": [[1, 1, 4, 4]], "map": "N/A"})
            continue

        try:
            # Get the next item from the iterator
            item = next(dsi)
        except StopIteration:
            print("End of dataset iterator reached.")
            gr.Warning("End of dataset shard reached.")  # Inform user
            break  # Stop fetching if iterator is exhausted

        try:
            # Process the item to extract images and metadata
            item_data = ev.item_to_images(subset, item)
            metadata = item_data["metadata"]

            if subset in _SUBSET_IMAGE_KEYS:
                base_keys, extra_keys = _SUBSET_IMAGE_KEYS[subset]
            else:
                # Handle potential unknown subsets gracefully
                print(f"Warning: Image extraction logic not defined for subset '{subset}'. Trying 'rgb'.")
                base_keys, extra_keys = ("rgb",), ()

            wanted_keys = base_keys if only_rgb else base_keys + extra_keys
            for key in wanted_keys:
                images.extend(item_data.get(key, []))

            map_link = utils.get_google_map_link(item_data, subset)
            metadata["map"] = f'<a href="{map_link}" target="_blank">🧭 View Map</a>' if map_link else "N/A"
            metadatas.append(metadata)
        except Exception as e:
            # Log only the id: the full item may carry large image payloads.
            item_id = _safe_item_id(item)
            print(f"Error processing item {item_id}: {e}")
            metadatas.append({"id": item_id, "error": str(e), "map": "Error"})

    print(f"Fetched {len(metadatas)} items for the batch.")
    # Convert metadata list to a Pandas DataFrame
    metadata_df = DataFrame(metadatas)
    return images, metadata_df
def update_gallery_columns(columns):
    """Update the number of columns in the image gallery.

    Args:
        columns (int): The desired number of columns. Values below 1 are
            raised to 1.

    Returns:
        A ``gr.Gallery`` carrying the new ``columns`` setting, or an empty
        ``gr.update()`` (no change) when ``columns`` is not a usable number,
        e.g. None because the input field was cleared.
    """
    print(f"Updating gallery columns to: {columns}")
    try:
        # Ensure columns is an integer of at least 1
        columns = max(1, int(columns))
    except (TypeError, ValueError, OverflowError):
        # Nothing sensible to apply; leave the gallery as it is.
        return gr.update()
    # Return the updated component configuration
    return gr.Gallery(columns=columns)
if __name__ == "__main__":
    # Build the viewer UI and wire the event handlers.
    # NOTE(review): the nesting below was reconstructed from a listing without
    # indentation — confirm which widgets belong to which column.
    with gr.Blocks(title="EarthView Viewer v5 fork", fill_height=True, theme=gr.themes.Default()) as demo:
        gr.Markdown(f"# Viewer for [{ev.DATASET}](https://huggingface.co/datasets/satellogic/EarthView) Dataset (Gradio 5)")
        with gr.Row():
            # Left column: dataset selection and paging controls.
            with gr.Column(scale=1):
                dataset_name = gr.Textbox(label="Dataset", value=ev.DATASET, interactive=False)
                subset_select = gr.Dropdown(choices=ev.get_subsets(), label="Subset", value="satellogic")
                split_name = gr.Textbox(label="Split", value="train")
                # precision=0 makes the number inputs deliver ints; the
                # handlers use them as shard indices, batch sizes and counts.
                initial_shard_input = gr.Number(
                    label="Load Shard",
                    value=10,
                    minimum=-1,
                    step=1,
                    precision=0,
                    info="Enter shard index (0-based) or -1 for all shards",
                )
                only_rgb_checkbox = gr.Checkbox(label="Only RGB Images", value=True)
                batch_size_input = gr.Number(value=10, label="Batch Size", minimum=1, step=1, precision=0)
                load_button = gr.Button("Load Dataset / Shard", variant="primary")
                shard_slider = gr.Slider(label="Shard", minimum=0, maximum=1, step=1, value=0)
                gallery_columns_input = gr.Number(value=5, label="Gallery Columns", minimum=1, step=1, precision=0)
                next_batch_button = gr.Button("Next Batch (from current shard)", scale=0)
            # Right column: image gallery and the per-item metadata table.
            with gr.Column(scale=4):
                image_gallery = gr.Gallery(
                    label="Dataset Images",
                    interactive=False,
                    object_fit="scale-down",
                    columns=5,
                    height="600px",
                    show_label=False,
                )
                # datatype="html" so the "map" column renders as a link.
                metadata_table = gr.DataFrame(datatype="html", wrap=True)

        # Load button: open the shard typed into the number field.
        load_button.click(
            fn=open_dataset,
            inputs=[dataset_name, subset_select, split_name, batch_size_input, initial_shard_input, only_rgb_checkbox],
            outputs=[shard_slider, image_gallery, metadata_table],
        )
        # Slider release: open the shard selected on the slider.
        shard_slider.release(
            fn=open_dataset,
            inputs=[dataset_name, subset_select, split_name, batch_size_input, shard_slider, only_rgb_checkbox],
            outputs=[shard_slider, image_gallery, metadata_table],
        )
        gallery_columns_input.change(
            fn=update_gallery_columns,
            inputs=[gallery_columns_input],
            outputs=[image_gallery],
        )
        # Next batch: continue from the iterator kept in app_state.
        next_batch_button.click(
            fn=get_images,
            inputs=[batch_size_input, only_rgb_checkbox],
            outputs=[image_gallery, metadata_table],
        )

    demo.launch(show_api=False)