"""
Hugging Face Hub dataset browser for Video Model Studio.

Handles searching, viewing, and downloading datasets from the Hub.
"""

import os
import shutil
import tempfile
import asyncio
import logging
import gradio as gr

from pathlib import Path
from typing import List, Dict, Optional, Tuple, Any, Union, Callable

from huggingface_hub import (
    HfApi,
    hf_hub_download,
    snapshot_download,
    list_datasets
)

from vms.config import NORMALIZE_IMAGES_TO, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
from vms.utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption, webdataset_handler

logger = logging.getLogger(__name__)
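# Illustrative usage sketch (the indexing below is hypothetical; download_file_group
# is async, so it is awaited or driven with asyncio):
#
#     api = HfApi()
#     browser = HubDatasetBrowser(api)
#     results = browser.search_datasets("video")          # [[dataset_id, title, downloads], ...]
#     info_md, counts, groups = browser.get_dataset_info(results[0][0])
#     status = asyncio.run(browser.download_file_group(results[0][0], "video", enable_splitting=True))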
|
|
|
|
|
class HubDatasetBrowser:
    """Handles interactions with Hugging Face Hub datasets"""

    def __init__(self, hf_api: HfApi):
        """Initialize with HfApi instance

        Args:
            hf_api: Hugging Face Hub API instance
        """
        self.hf_api = hf_api

    def search_datasets(self, query: str) -> List[List[str]]:
        """Search for datasets on the Hugging Face Hub

        Args:
            query: Search query string

        Returns:
            List of datasets matching the query [id, title, downloads]

        Note: We still return all columns internally, but the UI will only display the first column
        """
        try:
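            # Illustrative return value (hypothetical dataset ids):
            #   [["someuser/drone-clips", "drone-clips", 1200], ["someuser/webcam-footage", "webcam-footage", 87]]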
|
|
|
|
|
            # Fall back to a generic "video" search when no query is provided
            search_terms = query.strip() if query and query.strip() else "video"
            logger.info(f"Searching datasets with query: '{search_terms}'")

            datasets = list(self.hf_api.list_datasets(
                search=search_terms,
                limit=50
            ))

            results = []
            for ds in datasets:
                dataset_id = ds.id

                # Prefer the card's display name, falling back to the last path component of the id
                card_data = getattr(ds, "card_data", None)
                title = ""

                if card_data is not None and isinstance(card_data, dict):
                    title = card_data.get("name", "")

                if not title:
                    title = dataset_id.split("/")[-1]

                downloads = getattr(ds, "downloads", 0)
                if downloads is None:
                    downloads = 0

                results.append([dataset_id, title, downloads])

            # Most-downloaded datasets first
            results.sort(key=lambda x: x[2] if x[2] is not None else 0, reverse=True)

            logger.info(f"Found {len(results)} datasets matching '{search_terms}'")
            return results

        except Exception as e:
            logger.error(f"Error searching datasets: {str(e)}", exc_info=True)
            return [[f"Error: {str(e)}", "", ""]]
|
|
|
|
|
    def get_dataset_info(self, dataset_id: str) -> Tuple[str, Dict[str, int], Dict[str, List[str]]]:
        """Get detailed information about a dataset

        Args:
            dataset_id: The dataset ID to get information for

        Returns:
            Tuple of (markdown_info, file_counts, file_groups)
            - markdown_info: Markdown formatted string with dataset information
            - file_counts: Dictionary with counts of each file type
            - file_groups: Dictionary with lists of filenames grouped by type
        """
        try:
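            # Illustrative return value for a hypothetical dataset:
            #   ("### someuser/drone-clips\n\n<description...>",
            #    {"video": 2, "webdataset": 1},
            #    {"video": ["clip1.mp4", "clip2.webm"], "webdataset": ["shard-000.tar"]})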
|
|
            if not dataset_id:
                logger.warning("No dataset ID provided to get_dataset_info")
                return "No dataset selected", {}, {}

            logger.info(f"Getting info for dataset: {dataset_id}")

            dataset_info = self.hf_api.dataset_info(dataset_id)

            # Build the markdown summary shown in the UI
            info_text = f"### {dataset_info.id}\n\n"

            card_data = getattr(dataset_info, "card_data", None)
            description = ""

            if card_data is not None and isinstance(card_data, dict):
                description = card_data.get("description", "")

            if description:
                info_text += f"{description[:500]}{'...' if len(description) > 500 else ''}\n\n"

            # Group repository files by type: .mp4/.webm are videos, .tar archives are WebDataset shards
            file_groups = {
                "video": [],
                "webdataset": []
            }

            siblings = getattr(dataset_info, "siblings", None) or []

            for s in siblings:
                if not hasattr(s, 'rfilename'):
                    continue

                filename = s.rfilename
                if filename.lower().endswith((".mp4", ".webm")):
                    file_groups["video"].append(filename)
                elif filename.lower().endswith(".tar"):
                    file_groups["webdataset"].append(filename)

            file_counts = {
                "video": len(file_groups["video"]),
                "webdataset": len(file_groups["webdataset"])
            }

            logger.info(f"Successfully retrieved info for dataset: {dataset_id}")
            return info_text, file_counts, file_groups

        except Exception as e:
            logger.error(f"Error getting dataset info: {str(e)}", exc_info=True)
            return f"Error loading dataset information: {str(e)}", {}, {}
|
|
|
|
|
    async def download_file_group(
        self,
        dataset_id: str,
        file_type: str,
        enable_splitting: bool,
        progress_callback: Optional[Callable] = None
    ) -> str:
        """Download all files of a specific type from the dataset

        Args:
            dataset_id: The dataset ID
            file_type: Either "video" or "webdataset"
            enable_splitting: Whether to enable automatic video splitting
            progress_callback: Optional callback for progress updates

        Returns:
            Status message
        """
        try:
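            # Illustrative call (from an async context; the dataset id is hypothetical):
            #   status = await browser.download_file_group("someuser/drone-clips", "video", enable_splitting=True)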
|
|
|
|
|
            _, _, file_groups = self.get_dataset_info(dataset_id)
            files = file_groups.get(file_type, [])

            if not files:
                return f"No {file_type} files found in the dataset"

            logger.info(f"Downloading {len(files)} {file_type} files from dataset {dataset_id}")
            gr.Info(f"Starting download of {len(files)} {file_type} files from {dataset_id}")

            if progress_callback:
                progress_callback(0, desc=f"Starting download of {len(files)} {file_type} files", total=len(files))

            video_count = 0
            image_count = 0

            # Download each file into a temporary directory, then copy it into the staging area
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)

                for i, filename in enumerate(files):
                    try:
                        if progress_callback:
                            progress_callback(
                                i,
                                desc=f"Downloading file {i+1}/{len(files)}: {Path(filename).name}",
                                total=len(files)
                            )

                        file_path = hf_hub_download(
                            repo_id=dataset_id,
                            filename=filename,
                            repo_type="dataset",
                            local_dir=temp_path
                        )

                        file_path = Path(file_path)
                        logger.info(f"Downloaded file to {file_path}")

                        if file_type == "video":
                            # Videos go to the splitting queue unless splitting is disabled
                            target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
                            target_path = target_dir / file_path.name

                            # Avoid name collisions by appending a "___<n>" suffix to the stem
                            counter = 1
                            while target_path.exists():
                                stem = Path(file_path.name).stem
                                if "___" in stem:
                                    base_stem = stem.split("___")[0]
                                else:
                                    base_stem = stem
                                target_path = target_dir / f"{base_stem}___{counter}{Path(file_path.name).suffix}"
                                counter += 1

                            shutil.copy2(file_path, target_path)
                            logger.info(f"Processed video: {file_path.name} -> {target_path.name}")

                            # Look for a sidecar .txt caption stored next to the video in the repo
                            try:
                                txt_filename = f"{Path(filename).stem}.txt"
                                for possible_path in [
                                    Path(filename).with_suffix('.txt').as_posix(),
                                    (Path(filename).parent / txt_filename).as_posix(),
                                ]:
                                    try:
                                        txt_path = hf_hub_download(
                                            repo_id=dataset_id,
                                            filename=possible_path,
                                            repo_type="dataset",
                                            local_dir=temp_path
                                        )
                                        shutil.copy2(txt_path, target_path.with_suffix('.txt'))
                                        logger.info(f"Copied caption for {file_path.name}")
                                        break
                                    except Exception:
                                        pass
                            except Exception as e:
                                logger.warning(f"Error trying to download caption: {e}")

                            video_count += 1

                        elif file_type == "webdataset":
                            # Extract videos and images from the WebDataset shard into the staging areas
                            try:
                                logger.info(f"Processing WebDataset file: {file_path}")
                                vid_count, img_count = webdataset_handler.process_webdataset_shard(
                                    file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
                                )
                                video_count += vid_count
                                image_count += img_count
                            except Exception as e:
                                logger.error(f"Error processing WebDataset file {file_path}: {str(e)}", exc_info=True)

                    except Exception as e:
                        logger.warning(f"Error processing file {filename}: {e}")

            if progress_callback:
                progress_callback(len(files), desc="Download complete", total=len(files))

            if file_type == "video":
                status_msg = f"Successfully imported {video_count} videos from dataset {dataset_id}"
            elif file_type == "webdataset":
                parts = []
                if video_count > 0:
                    parts.append(f"{video_count} video{'s' if video_count != 1 else ''}")
                if image_count > 0:
                    parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")

                if parts:
                    status_msg = f"Successfully imported {' and '.join(parts)} from WebDataset archives"
                else:
                    status_msg = "No media was found in the WebDataset archives"
            else:
                status_msg = f"Unknown file type: {file_type}"

            logger.info(f"✅ Download complete! {status_msg}")
            gr.Info(f"✅ Download complete! {status_msg}")

            return status_msg

        except Exception as e:
            error_msg = f"Error downloading {file_type} files: {str(e)}"
            logger.error(error_msg, exc_info=True)
            gr.Error(error_msg)
            return error_msg
|
|
|
|
|
    async def download_dataset(
        self,
        dataset_id: str,
        enable_splitting: bool,
        progress_callback: Optional[Callable] = None
    ) -> Tuple[str, str]:
        """Download a dataset and process its video/image content

        Args:
            dataset_id: The dataset ID to download
            enable_splitting: Whether to enable automatic video splitting
            progress_callback: Optional callback for progress tracking

        Returns:
            Tuple of (loading_msg, status_msg)
        """
        if not dataset_id:
            logger.warning("No dataset ID provided for download")
            return "No dataset selected", "Please select a dataset first"
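        # Illustrative call (from an async context; the dataset id is hypothetical):
        #   loading_md, status = await browser.download_dataset("someuser/drone-clips", enable_splitting=False)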
|
|
|
|
|
        try:
            logger.info(f"Starting download of dataset: {dataset_id}")
            loading_msg = f"## Downloading dataset: {dataset_id}\n\nThis may take some time depending on the dataset size..."
            status_msg = f"Downloading dataset: {dataset_id}..."

            dataset_info = self.hf_api.dataset_info(dataset_id)

            # Collect the video and WebDataset files listed in the repository
            video_files = []
            tar_files = []

            siblings = getattr(dataset_info, "siblings", None) or []
            if siblings:
                video_files = [s.rfilename for s in siblings if hasattr(s, 'rfilename') and s.rfilename.lower().endswith((".mp4", ".webm"))]
                tar_files = [s.rfilename for s in siblings if hasattr(s, 'rfilename') and s.rfilename.lower().endswith(".tar")]

            total_files = len(video_files) + len(tar_files)
            if progress_callback:
                progress_callback(0, desc=f"Starting download of dataset: {dataset_id}", total=total_files)

            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)
                files_processed = 0

                if video_files:
                    loading_msg = f"{loading_msg}\n\nDownloading {len(video_files)} video files..."
                    logger.info(f"Downloading {len(video_files)} video files from {dataset_id}")

                    for i, video_file in enumerate(video_files):
                        if progress_callback:
                            progress_callback(
                                files_processed,
                                desc=f"Downloading video {i+1}/{len(video_files)}: {Path(video_file).name}",
                                total=total_files
                            )

                        try:
                            file_path = hf_hub_download(
                                repo_id=dataset_id,
                                filename=video_file,
                                repo_type="dataset",
                                local_dir=temp_path
                            )

                            # Look for a sidecar .txt caption stored next to the video in the repo
                            txt_filename = f"{Path(video_file).stem}.txt"
                            txt_path = None
                            for possible_path in [
                                Path(video_file).with_suffix('.txt').as_posix(),
                                (Path(video_file).parent / txt_filename).as_posix(),
                            ]:
                                try:
                                    txt_path = hf_hub_download(
                                        repo_id=dataset_id,
                                        filename=possible_path,
                                        repo_type="dataset",
                                        local_dir=temp_path
                                    )
                                    logger.info(f"Found caption file for {video_file}: {possible_path}")
                                    break
                                except Exception as e:
                                    logger.debug(f"No caption at {possible_path}: {str(e)}")

                            status_msg = f"Downloaded video {i+1}/{len(video_files)} from {dataset_id}"
                            logger.info(status_msg)
                            files_processed += 1
                        except Exception as e:
                            logger.warning(f"Error downloading {video_file}: {e}")

                if tar_files:
                    loading_msg = f"{loading_msg}\n\nDownloading {len(tar_files)} WebDataset files..."
                    logger.info(f"Downloading {len(tar_files)} WebDataset files from {dataset_id}")

                    for i, tar_file in enumerate(tar_files):
                        if progress_callback:
                            progress_callback(
                                files_processed,
                                desc=f"Downloading WebDataset {i+1}/{len(tar_files)}: {Path(tar_file).name}",
                                total=total_files
                            )

                        try:
                            file_path = hf_hub_download(
                                repo_id=dataset_id,
                                filename=tar_file,
                                repo_type="dataset",
                                local_dir=temp_path
                            )
                            status_msg = f"Downloaded WebDataset {i+1}/{len(tar_files)} from {dataset_id}"
                            logger.info(status_msg)
                            files_processed += 1
                        except Exception as e:
                            logger.warning(f"Error downloading {tar_file}: {e}")

                # Fall back to downloading the whole repository when no individual media files were listed
                if not video_files and not tar_files:
                    loading_msg = f"{loading_msg}\n\nDownloading entire dataset repository..."
                    logger.info(f"No specific media files found, downloading entire repository for {dataset_id}")

                    if progress_callback:
                        progress_callback(0, desc=f"Downloading entire repository for {dataset_id}", total=1)

                    try:
                        snapshot_download(
                            repo_id=dataset_id,
                            repo_type="dataset",
                            local_dir=temp_path
                        )
                        status_msg = f"Downloaded entire repository for {dataset_id}"
                        logger.info(status_msg)

                        if progress_callback:
                            progress_callback(1, desc="Repository download complete", total=1)
                    except Exception as e:
                        logger.error(f"Error downloading dataset snapshot: {e}", exc_info=True)
                        return loading_msg, f"Error downloading dataset: {str(e)}"

                loading_msg = f"{loading_msg}\n\nProcessing downloaded files..."
                logger.info(f"Processing downloaded files from {dataset_id}")

                if progress_callback:
                    progress_callback(0, desc="Processing downloaded files", total=100)

                video_count = 0
                image_count = 0
                tar_count = 0

                async def process_files():
                    nonlocal video_count, image_count, tar_count

                    # Count files first so progress can be reported as a percentage
                    file_count = 0
                    for root, _, files in os.walk(temp_path):
                        file_count += len(files)

                    processed = 0

                    # Walk the downloaded tree and route each file to the right staging area
                    for root, _, files in os.walk(temp_path):
                        for file in files:
                            file_path = Path(root) / file

                            if progress_callback and processed % 5 == 0:
                                if file_count > 0:
                                    progress_percent = int((processed / file_count) * 100)
                                    progress_callback(
                                        progress_percent,
                                        desc=f"Processing files: {processed}/{file_count}",
                                        total=100
                                    )

                            if file.lower().endswith((".mp4", ".webm")):
                                # Videos go to the splitting queue unless splitting is disabled
                                target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
                                target_path = target_dir / file_path.name

                                # Avoid name collisions by appending a "___<n>" suffix to the stem
                                counter = 1
                                while target_path.exists():
                                    stem = Path(file_path.name).stem
                                    if "___" in stem:
                                        base_stem = stem.split("___")[0]
                                    else:
                                        base_stem = stem
                                    target_path = target_dir / f"{base_stem}___{counter}{Path(file_path.name).suffix}"
                                    counter += 1

                                shutil.copy2(file_path, target_path)
                                logger.info(f"Processed video from dataset: {file_path.name} -> {target_path.name}")

                                # Copy the caption file along with the video if one was downloaded
                                txt_path = file_path.with_suffix('.txt')
                                if txt_path.exists():
                                    shutil.copy2(txt_path, target_path.with_suffix('.txt'))
                                    logger.info(f"Copied caption for {file_path.name}")

                                video_count += 1

                            elif is_image_file(file_path):
                                # Normalize images into the configured format before staging
                                target_path = STAGING_PATH / f"{file_path.stem}.{NORMALIZE_IMAGES_TO}"

                                counter = 1
                                while target_path.exists():
                                    target_path = STAGING_PATH / f"{file_path.stem}___{counter}.{NORMALIZE_IMAGES_TO}"
                                    counter += 1

                                if normalize_image(file_path, target_path):
                                    logger.info(f"Processed image from dataset: {file_path.name} -> {target_path.name}")

                                    txt_path = file_path.with_suffix('.txt')
                                    if txt_path.exists():
                                        caption = txt_path.read_text()
                                        caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
                                        target_path.with_suffix('.txt').write_text(caption)
                                        logger.info(f"Processed caption for {file_path.name}")

                                    image_count += 1

                            elif file.lower().endswith(".tar"):
                                # Extract videos and images from the WebDataset shard into the staging areas
                                try:
                                    logger.info(f"Processing WebDataset file from dataset: {file}")
                                    vid_count, img_count = webdataset_handler.process_webdataset_shard(
                                        file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
                                    )
                                    tar_count += 1
                                    video_count += vid_count
                                    image_count += img_count
                                    logger.info(f"Extracted {vid_count} videos and {img_count} images from {file}")
                                except Exception as e:
                                    logger.error(f"Error processing WebDataset file {file_path}: {str(e)}", exc_info=True)

                            processed += 1

                await process_files()

            if progress_callback:
                progress_callback(100, desc="Processing complete", total=100)

            # Build a human-readable summary of what was imported
            parts = []
            if video_count > 0:
                parts.append(f"{video_count} video{'s' if video_count != 1 else ''}")
            if image_count > 0:
                parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
            if tar_count > 0:
                parts.append(f"{tar_count} WebDataset archive{'s' if tar_count != 1 else ''}")

            if parts:
                status = f"Successfully imported {', '.join(parts)} from dataset {dataset_id}"
                loading_msg = f"{loading_msg}\n\n✅ Success! {status}"
                logger.info(status)
            else:
                status = f"No supported media files found in dataset {dataset_id}"
                loading_msg = f"{loading_msg}\n\n⚠️ {status}"
                logger.warning(status)

            gr.Info(status)
            return loading_msg, status

        except Exception as e:
            error_msg = f"Error downloading dataset {dataset_id}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            return f"Error: {error_msg}", error_msg