Julian Bilcke
committed on
Commit · 2ae624b
1 Parent(s): 9fb885e
fix issue with scene splitting
- vms/ui/app_ui.py +8 -5
- vms/ui/project/services/importing/file_upload.py +26 -60
- vms/ui/project/services/importing/hub_dataset.py +2 -2
- vms/ui/project/services/importing/import_service.py +11 -6
- vms/ui/project/services/importing/youtube.py +7 -3
- vms/ui/project/services/splitting.py +4 -1
- vms/ui/project/tabs/import_tab/import_tab.py +27 -12
- vms/ui/project/tabs/import_tab/upload_tab.py +15 -4
- vms/ui/project/tabs/import_tab/youtube_tab.py +7 -3
- vms/ui/project/tabs/preview_tab.py +4 -2
- vms/ui/project/tabs/train_tab.py +14 -7
- vms/utils/webdataset_handler.py +3 -1
vms/ui/app_ui.py
CHANGED
@@ -396,11 +396,14 @@ class AppUI:
                 model_version_val = available_model_versions[0]
                 logger.info(f"Using first available model version: {model_version_val}")
 
-                # IMPORTANT:
-                # This is
+                # IMPORTANT: Create a new list of simple strings for the dropdown choices
+                # This ensures each choice is a single string, not a tuple or other structure
+                simple_choices = [str(version) for version in available_model_versions]
+
+                # Update the dropdown choices directly in the UI component
                 try:
-                    self.project_tabs["train_tab"].components["model_version"].choices =
-                    logger.info(f"Updated model_version dropdown choices: {len(
+                    self.project_tabs["train_tab"].components["model_version"].choices = simple_choices
+                    logger.info(f"Updated model_version dropdown choices: {len(simple_choices)} options")
                 except Exception as e:
                     logger.error(f"Error updating model_version dropdown: {str(e)}")
             else:
@@ -410,7 +413,7 @@
                     self.project_tabs["train_tab"].components["model_version"].choices = []
                 except Exception as e:
                     logger.error(f"Error setting empty model_version choices: {str(e)}")
-
+
             # Ensure training_type is a valid display name
             training_type_val = ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])
             if training_type_val not in TRAINING_TYPES:
vms/ui/project/services/importing/file_upload.py
CHANGED
@@ -22,20 +22,23 @@ logger = logging.getLogger(__name__)
 class FileUploadHandler:
     """Handles processing of uploaded files"""
 
-    def process_uploaded_files(self, file_paths: List[str]) -> str:
+    def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool) -> str:
         """Process uploaded file (ZIP, TAR, MP4, or image)
 
         Args:
             file_paths: File paths to the uploaded files from Gradio
+            enable_splitting: Whether to enable automatic video splitting
 
         Returns:
             Status message string
         """
+        print(f"process_uploaded_files called with enable_splitting={enable_splitting} and file_paths = {str(file_paths)}")
         if not file_paths or len(file_paths) == 0:
             logger.warning("No files provided to process_uploaded_files")
             return "No files provided"
-
+
         for file_path in file_paths:
+            print(f" - {str(file_path)}")
             file_path = Path(file_path)
             try:
                 original_name = file_path.name
@@ -45,11 +48,11 @@
             file_ext = file_path.suffix.lower()
 
             if file_ext == '.zip':
-                return self.process_zip_file(file_path)
+                return self.process_zip_file(file_path, enable_splitting)
             elif file_ext == '.tar':
-                return self.process_tar_file(file_path)
+                return self.process_tar_file(file_path, enable_splitting)
             elif file_ext == '.mp4' or file_ext == '.webm':
-                return self.process_mp4_file(file_path, original_name)
+                return self.process_mp4_file(file_path, original_name, enable_splitting)
             elif is_image_file(file_path):
                 return self.process_image_file(file_path, original_name)
             else:
@@ -60,56 +63,12 @@
                 logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True)
                 raise gr.Error(f"Error processing file: {str(e)}")
 
-    def process_image_file(self, file_path: Path, original_name: str) -> str:
-        """Process a single image file
-
-        Args:
-            file_path: Path to the image
-            original_name: Original filename
-
-        Returns:
-            Status message string
-        """
-        try:
-            # Create a unique filename with configured extension
-            stem = Path(original_name).stem
-            target_path = STAGING_PATH / f"{stem}.{NORMALIZE_IMAGES_TO}"
-
-            # If file already exists, add number suffix
-            counter = 1
-            while target_path.exists():
-                target_path = STAGING_PATH / f"{stem}___{counter}.{NORMALIZE_IMAGES_TO}"
-                counter += 1
-
-            logger.info(f"Processing image file: {original_name} -> {target_path}")
-
-            # Convert to normalized format and remove black bars
-            success = normalize_image(file_path, target_path)
-
-            if not success:
-                logger.error(f"Failed to process image: {original_name}")
-                raise gr.Error(f"Failed to process image: {original_name}")
-
-            # Handle caption
-            src_caption_path = file_path.with_suffix('.txt')
-            if src_caption_path.exists():
-                caption = src_caption_path.read_text()
-                caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
-                target_path.with_suffix('.txt').write_text(caption)
-
-            logger.info(f"Successfully stored image: {target_path.name}")
-            gr.Info(f"Successfully stored image: {target_path.name}")
-            return f"Successfully stored image: {target_path.name}"
-
-        except Exception as e:
-            logger.error(f"Error processing image file: {str(e)}", exc_info=True)
-            raise gr.Error(f"Error processing image file: {str(e)}")
-
-    def process_zip_file(self, file_path: Path) -> str:
+    def process_zip_file(self, file_path: Path, enable_splitting: bool) -> str:
         """Process uploaded ZIP file containing media files or WebDataset tar files
 
         Args:
             file_path: Path to the uploaded ZIP file
+            enable_splitting: Whether to enable automatic video splitting
 
         Returns:
             Status message string
@@ -143,17 +102,18 @@
                         logger.info(f"Processing WebDataset archive from ZIP: {file}")
                         # Process WebDataset shard
                         vid_count, img_count = webdataset_handler.process_webdataset_shard(
-                            file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
+                            file_path, VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH, STAGING_PATH
                         )
                         video_count += vid_count
                         image_count += img_count
                         tar_count += 1
                     elif is_video_file(file_path):
-                        #
-
+                        # Choose target directory based on auto-splitting setting
+                        target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
+                        target_path = target_dir / file_path.name
                         counter = 1
                         while target_path.exists():
-                            target_path =
+                            target_path = target_dir / f"{file_path.stem}___{counter}{file_path.suffix}"
                             counter += 1
                         shutil.copy2(file_path, target_path)
                         logger.info(f"Extracted video from ZIP: {file} -> {target_path.name}")
@@ -208,11 +168,12 @@
             logger.error(f"Error processing ZIP: {str(e)}", exc_info=True)
             raise gr.Error(f"Error processing ZIP: {str(e)}")
 
-    def process_tar_file(self, file_path: Path) -> str:
+    def process_tar_file(self, file_path: Path, enable_splitting: bool) -> str:
         """Process a WebDataset tar file
 
         Args:
             file_path: Path to the uploaded tar file
+            enable_splitting: Whether to enable automatic video splitting
 
         Returns:
             Status message string
@@ -220,7 +181,7 @@
         try:
             logger.info(f"Processing WebDataset TAR file: {file_path}")
             video_count, image_count = webdataset_handler.process_webdataset_shard(
-                file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
+                file_path, VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH, STAGING_PATH
             )
 
             # Generate status message
@@ -243,25 +204,30 @@
             logger.error(f"Error processing WebDataset tar file: {str(e)}", exc_info=True)
             raise gr.Error(f"Error processing WebDataset tar file: {str(e)}")
 
-    def process_mp4_file(self, file_path: Path, original_name: str) -> str:
+    def process_mp4_file(self, file_path: Path, original_name: str, enable_splitting: bool) -> str:
         """Process a single video file
 
         Args:
             file_path: Path to the file
            original_name: Original filename
+            enable_splitting: Whether to enable automatic video splitting
 
         Returns:
             Status message string
         """
+        print(f"process_mp4_file(self, file_path={str(file_path)}, original_name={str(original_name)}, enable_splitting={enable_splitting})")
         try:
+            # Choose target directory based on auto-splitting setting
+            target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
+            print(f"target_dir = {target_dir}")
             # Create a unique filename
-            target_path =
+            target_path = target_dir / original_name
 
             # If file already exists, add number suffix
             counter = 1
             while target_path.exists():
                 stem = Path(original_name).stem
-                target_path =
+                target_path = target_dir / f"{stem}___{counter}.mp4"
                 counter += 1
 
             logger.info(f"Processing video file: {original_name} -> {target_path}")
vms/ui/project/services/importing/hub_dataset.py
CHANGED
@@ -168,7 +168,7 @@ class HubDatasetBrowser:
         self,
         dataset_id: str,
         file_type: str,
-        enable_splitting: bool
+        enable_splitting: bool,
         progress_callback: Optional[Callable] = None
     ) -> str:
         """Download all files of a specific type from the dataset
@@ -328,7 +328,7 @@
     async def download_dataset(
         self,
         dataset_id: str,
-        enable_splitting: bool
+        enable_splitting: bool,
         progress_callback: Optional[Callable] = None
     ) -> Tuple[str, str]:
         """Download a dataset and process its video/image content
vms/ui/project/services/importing/import_service.py
CHANGED
@@ -28,32 +28,37 @@ class ImportingService:
         self.youtube_handler = YouTubeDownloader()
         self.hub_browser = HubDatasetBrowser(self.hf_api)
 
-    def process_uploaded_files(self, file_paths: List[str]) -> str:
+    def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool) -> str:
         """Process uploaded file (ZIP, TAR, MP4, or image)
 
         Args:
             file_paths: File paths to the uploaded files from Gradio
+            enable_splitting: Whether to enable automatic video splitting
 
         Returns:
             Status message string
         """
+        print(f"process_uploaded_files(..., enable_splitting = { enable_splitting})")
         if not file_paths or len(file_paths) == 0:
             logger.warning("No files provided to process_uploaded_files")
             return "No files provided"
 
-        return self.file_handler.process_uploaded_files(file_paths)
+        print(f"process_uploaded_files(..., enable_splitting = {enable_splitting:})")
+        print(f"process_uploaded_files: calling self.file_handler.process_uploaded_files")
+        return self.file_handler.process_uploaded_files(file_paths, enable_splitting)
 
-    def download_youtube_video(self, url: str, progress=None) -> str:
+    def download_youtube_video(self, url: str, enable_splitting: bool, progress=None) -> str:
         """Download a video from YouTube
 
         Args:
             url: YouTube video URL
+            enable_splitting: Whether to enable automatic video splitting
             progress: Optional Gradio progress indicator
 
         Returns:
             Status message string
         """
-        return self.youtube_handler.download_video(url, progress)
+        return self.youtube_handler.download_video(url, enable_splitting, progress)
 
     def search_datasets(self, query: str) -> List[List[str]]:
         """Search for datasets on the Hugging Face Hub
@@ -80,7 +85,7 @@
     async def download_dataset(
         self,
         dataset_id: str,
-        enable_splitting: bool
+        enable_splitting: bool,
         progress_callback: Optional[Callable] = None
     ) -> Tuple[str, str]:
         """Download a dataset and process its video/image content
@@ -99,7 +104,7 @@
         self,
         dataset_id: str,
         file_type: str,
-        enable_splitting: bool
+        enable_splitting: bool,
         progress_callback: Optional[Callable] = None
     ) -> str:
         """Download a group of files (videos or WebDatasets)
vms/ui/project/services/importing/youtube.py
CHANGED
@@ -17,11 +17,12 @@ logger = logging.getLogger(__name__)
 class YouTubeDownloader:
     """Handles downloading videos from YouTube"""
 
-    def download_video(self, url: str, progress: Optional[Callable] = None) -> str:
+    def download_video(self, url: str, enable_splitting: bool, progress: Optional[Callable] = None) -> str:
         """Download a video from YouTube
 
         Args:
             url: YouTube video URL
+            enable_splitting: Whether to enable automatic video splitting
             progress: Optional Gradio progress indicator
 
         Returns:
@@ -40,7 +41,10 @@
                     if progress else None)
 
         video_id = yt.video_id
-
+
+        # Choose target directory based on auto-splitting setting
+        target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
+        output_path = target_dir / f"{video_id}.mp4"
 
         # Download highest quality progressive MP4
         if progress:
@@ -58,7 +62,7 @@
             logger.info("Starting YouTube video download...")
             progress(0, desc="Starting download...")
 
-        video.download(output_path=str(
+        video.download(output_path=str(target_dir), filename=f"{video_id}.mp4")
 
         # Update UI
         if progress:
vms/ui/project/services/splitting.py
CHANGED
@@ -63,7 +63,7 @@ class SplittingService:
         """Process a single video file to detect and split scenes"""
         try:
             self._processing_status[video_path.name] = f'Processing video "{video_path.name}"...'
-
+            print(f'Going to split scenes for video "{video_path.name}"...')
             parent_caption_path = video_path.with_suffix('.txt')
             # Create output path for split videos
             base_name, _ = extract_scene_info(video_path.name)
@@ -180,6 +180,7 @@
 
     async def start_processing(self, enable_splitting: bool) -> None:
         """Start background processing of unprocessed videos"""
+        #print(f"start_processing(enable_splitting={enable_splitting}), self.processing = {self.processing}")
         if self.processing:
             return
@@ -188,6 +189,8 @@
             # Process each video
             for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"):
                 self._current_file = video_file.name
+                #print(f"calling await self.process_video(video_file, {enable_splitting})")
+
                 await self.process_video(video_file, enable_splitting)
 
         finally:
vms/ui/project/tabs/import_tab/import_tab.py
CHANGED
@@ -90,25 +90,37 @@ class ImportTab(BaseTab):
         self.youtube_tab.connect_events()
         self.hub_tab.connect_events()
 
-    def on_import_success(
+    def on_import_success(
+        self,
+        enable_splitting: bool,
+        enable_automatic_content_captioning: bool,
+        prompt_prefix: str
+    ):
         """Handle successful import of files"""
+        #print(f"on_import_success(self, enable_splitting={enable_splitting}, enable_automatic_content_captioning={enable_automatic_content_captioning}, prompt_prefix={prompt_prefix})")
         # If splitting is disabled, we need to directly move videos to staging
-        if
-            #
-            self._start_copy_to_staging_bg()
-            msg = "Copying videos to staging directory without splitting..."
-        else:
+        if enable_splitting:
+            #print("on_import_success: -> splitting enabled!")
             # Start scene detection if not already running and there are videos to process
             if not self.app.splitting.is_processing():
+                #print("on_import_success: -> calling self._start_scene_detection_bg(enable_splitting)")
                 # Start the scene detection in a separate thread
                 self._start_scene_detection_bg(enable_splitting)
                 msg = "Starting automatic scene detection..."
             else:
                 msg = "Scene detection already running..."
 
-
-
-
+            # Copy files to training directory
+            self.app.tabs["caption_tab"].copy_files_to_training_dir(prompt_prefix)
+        else:
+            #print("on_import_success: -> splitting NOT enabled")
+            # Copy files without splitting
+            self._start_copy_to_staging_bg()
+            msg = "Copying videos to staging directory without splitting..."
+
+            # Also immediately copy to training directory
+            self.app.tabs["caption_tab"].copy_files_to_training_dir(prompt_prefix)
+
         # Start auto-captioning if enabled
         if enable_automatic_content_captioning:
             self._start_captioning_bg(DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, prompt_prefix)
@@ -122,8 +134,9 @@
             logger.warning("Cannot switch tabs - project_tabs_component not available")
         return None, msg
 
-    def _start_scene_detection_bg(self, enable_splitting):
+    def _start_scene_detection_bg(self, enable_splitting: bool):
         """Start scene detection in a background thread"""
+        print(f"_start_scene_detection_bg(enable_splitting={enable_splitting})")
         def run_async_in_thread():
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
@@ -207,11 +220,13 @@
         thread.daemon = True
         thread.start()
 
-    async def update_titles_after_import(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
+    async def update_titles_after_import(self, enable_splitting: bool, enable_automatic_content_captioning: bool, prompt_prefix: str):
         """Handle post-import updates including titles"""
        # Call the non-async version since we need to return immediately for the UI
         tabs, status_msg = self.on_import_success(
-            enable_splitting,
+            enable_splitting,
+            enable_automatic_content_captioning,
+            prompt_prefix
         )
 
         # Get updated titles
vms/ui/project/tabs/import_tab/upload_tab.py
CHANGED
@@ -62,11 +62,22 @@ class UploadTab(BaseTab):
             logger.warning("import_status component is not set in UploadTab")
             return
 
-        # File upload event
+        # File upload event with enable_splitting parameter
         upload_event = self.components["files"].upload(
-            fn=
-            inputs=[self.components["files"]],
+            fn=self.app.importing.process_uploaded_files,
+            inputs=[self.components["files"], self.components["enable_automatic_video_split"]],
             outputs=[self.components["import_status"]]
+        ).success(
+            fn=self.app.tabs["import_tab"].on_import_success,
+            inputs=[
+                self.components["enable_automatic_video_split"],
+                self.components["enable_automatic_content_captioning"],
+                self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
+            ],
+            outputs=[
+                self.app.project_tabs_component,
+                self.components["import_status"]
+            ]
         )
 
         # Only add success handler if all required components exist
@@ -102,4 +113,4 @@
             )
         except (AttributeError, KeyError) as e:
             logger.error(f"Error connecting event handlers in UploadTab: {str(e)}")
-            # Continue without the success handler
+            # Continue without the success handler
vms/ui/project/tabs/import_tab/youtube_tab.py
CHANGED
@@ -83,8 +83,8 @@ class YouTubeTab(BaseTab):
 
         # YouTube download event
         download_event = self.components["youtube_download_btn"].click(
-            fn=self.
-            inputs=[self.components["youtube_url"]],
+            fn=self.download_youtube_with_splitting,
+            inputs=[self.components["youtube_url"], self.components["enable_automatic_video_split"]],
             outputs=[self.components["import_status"]]
         )
 
@@ -106,4 +106,8 @@
             )
         except (AttributeError, KeyError) as e:
             logger.error(f"Error connecting success handler in YouTubeTab: {str(e)}")
-            # Continue without the success handler
+            # Continue without the success handler
+
+    def download_youtube_with_splitting(self, url, enable_splitting):
+        """Download YouTube video with splitting option"""
+        return self.app.importing.download_youtube_video(url, enable_splitting, gr.Progress())
vms/ui/project/tabs/preview_tab.py
CHANGED
@@ -200,8 +200,10 @@ class PreviewTab(BaseTab):
         # Return just the model IDs as a list of simple strings
         version_ids = list(MODEL_VERSIONS.get(internal_type, {}).keys())
         logger.info(f"Found {len(version_ids)} versions for {model_type}: {version_ids}")
-
-
+
+        # Ensure they're all strings
+        return [str(version) for version in version_ids]
+
     def get_default_model_version(self, model_type: str) -> str:
         """Get default model version for the given model type"""
         # Convert UI display name to internal name
vms/ui/project/tabs/train_tab.py
CHANGED
@@ -462,12 +462,15 @@ class TrainTab(BaseTab):
             # Update UI state with proper model_type first
             self.app.update_ui_state(model_type=model_type)
 
+            # Ensure model_versions is a simple list of strings
+            model_versions = [str(version) for version in model_versions]
+
             # Create a new dropdown with the updated choices
             if not model_versions:
                 logger.warning(f"No model versions available for {model_type}, using empty list")
                 # Return empty dropdown to avoid errors
                 return gr.Dropdown(choices=[], value=None)
-
+
             # Ensure default_version is in model_versions
             if default_version not in model_versions and model_versions:
                 default_version = model_versions[0]
@@ -481,8 +484,7 @@
             logger.error(f"Error in update_model_versions: {str(e)}")
             # Return empty dropdown to avoid errors
             return gr.Dropdown(choices=[], value=None)
-
-
+
     def handle_training_start(
         self, preset, model_type, model_version, training_type,
         lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
@@ -561,7 +563,9 @@
         # Return just the model IDs as a list of simple strings
         version_ids = list(MODEL_VERSIONS.get(internal_type, {}).keys())
         logger.info(f"Found {len(version_ids)} versions for {model_type}: {version_ids}")
-
+
+        # Ensure they're all strings
+        return [str(version) for version in version_ids]
 
     def get_default_model_version(self, model_type: str) -> str:
         """Get default model version for the given model type"""
@@ -749,9 +753,6 @@
             model_versions = self.get_model_version_choices(model_display_name)
             default_model_version = self.get_default_model_version(model_display_name)
 
-            # Create the model version dropdown update
-            model_version_update = gr.Dropdown(choices=model_versions, value=default_model_version)
-
             # Ensure we have valid choices and values
             if not model_versions:
                 logger.warning(f"No versions found for {model_display_name}, using empty list")
@@ -761,6 +762,12 @@
                 default_model_version = model_versions[0]
                 logger.info(f"Reset default version to first available: {default_model_version}")
 
+            # Ensure model_versions is a simple list of strings
+            model_versions = [str(version) for version in model_versions]
+
+            # Create the model version dropdown update
+            model_version_update = gr.Dropdown(choices=model_versions, value=default_model_version)
+
             # Return values in the same order as the output components
             return (
                 model_display_name,
vms/utils/webdataset_handler.py
CHANGED
@@ -41,7 +41,9 @@ def process_webdataset_shard(
     """
     video_count = 0
     image_count = 0
-
+
+    print(f"videos_output_dir = {videos_output_dir}")
+    print(f"staging_output_dir = {staging_output_dir}")
    try:
         # Dictionary to store grouped files by prefix
         grouped_files = {}
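
The temporary `print`s expose the contract the rest of the commit relies on: the caller decides where extracted videos go (`videos_output_dir`), while images always land in staging. A usage sketch of the call as made by the import handlers above (shard name and directories are hypothetical):

```python
from pathlib import Path

from vms.utils import webdataset_handler  # the module shown in this diff

# Hypothetical stand-ins for the VMS config constants
VIDEOS_TO_SPLIT_PATH = Path("data/videos_to_split")
STAGING_PATH = Path("data/staging")

enable_splitting = True
video_count, image_count = webdataset_handler.process_webdataset_shard(
    Path("shard-00000.tar"),                                     # hypothetical shard
    VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH,  # videos_output_dir
    STAGING_PATH,                                                # staging_output_dir
)
print(f"extracted {video_count} videos and {image_count} images")
```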
|