Julian Bilcke committed
Commit · 158671a
1 Parent(s): 4cc92e0

fix

Browse files:
- vms/ui/app_ui.py (+1 -1)
- vms/ui/project/services/training.py (+180 -172)
- vms/ui/project/tabs/preview_tab.py (+7 -11)
- vms/ui/project/tabs/train_tab.py (+47 -11)
vms/ui/app_ui.py CHANGED

@@ -392,7 +392,7 @@ class AppUI:
             versions = list(MODEL_VERSIONS.get(model_internal_type, {}).keys())
             if versions:
                 model_version_val = versions[0]
-
+
         # Ensure training_type is a valid display name
         training_type_val = ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])
         if training_type_val not in TRAINING_TYPES:
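
The hunk above only adjusts whitespace around the block that restores dropdown defaults from the saved UI state: the model version falls back to the first entry registered for the model type, and the training type is kept only if it is still a valid display name. A minimal standalone sketch of that fallback pattern (the dictionary contents below are invented for illustration, not the real VMS tables):

# Hypothetical stand-ins for the real MODEL_VERSIONS / TRAINING_TYPES tables.
MODEL_VERSIONS = {"wan": {"Wan-AI/Wan2.1-T2V-1.3B": {}, "Wan-AI/Wan2.1-T2V-14B": {}}}
TRAINING_TYPES = {"LoRA Finetune": "lora", "Full Finetune": "full"}

def restore_defaults(model_internal_type: str, saved_training_type: str) -> tuple[str, str]:
    # Default model version: the first version registered for this model type.
    versions = list(MODEL_VERSIONS.get(model_internal_type, {}).keys())
    model_version_val = versions[0] if versions else ""

    # Keep the saved training type only if it is still a valid display name.
    training_type_val = saved_training_type
    if training_type_val not in TRAINING_TYPES:
        training_type_val = list(TRAINING_TYPES.keys())[0]
    return model_version_val, training_type_val

print(restore_defaults("wan", "LoRA Finetune"))
# ('Wan-AI/Wan2.1-T2V-1.3B', 'LoRA Finetune')
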
vms/ui/project/services/training.py CHANGED

@@ -14,6 +14,7 @@ import zipfile
 import logging
 import traceback
 import threading
+import fcntl
 import select

 from typing import Any, Optional, Dict, List, Union, Tuple

@@ -63,6 +64,8 @@ class TrainingService:
         self.pid_file = OUTPUT_PATH / "training.pid"
         self.log_file = OUTPUT_PATH / "training.log"

+        self.file_lock = threading.Lock()
+
         self.file_handler = None
         self.setup_logging()
         self.ensure_valid_ui_state_file()

@@ -131,67 +134,69 @@ class TrainingService:
         """Save current UI state to file with validation"""
         ui_state_file = OUTPUT_PATH / "ui_state.json"

+        # Use a lock to prevent concurrent writes
+        with self.file_lock:
+            # Validate values before saving
+            validated_values = {}
+            default_state = {
+                "model_type": list(MODEL_TYPES.keys())[0],
+                "model_version": "",
+                "training_type": list(TRAINING_TYPES.keys())[0],
+                "lora_rank": DEFAULT_LORA_RANK_STR,
+                "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+                "train_steps": DEFAULT_NB_TRAINING_STEPS,
+                "batch_size": DEFAULT_BATCH_SIZE,
+                "learning_rate": DEFAULT_LEARNING_RATE,
+                "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+                "training_preset": list(TRAINING_PRESETS.keys())[0],
+                "num_gpus": DEFAULT_NUM_GPUS,
+                "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
+                "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS
+            }
+
+            # Copy default values first
+            validated_values = default_state.copy()
+
+            # Update with provided values, converting types as needed
+            for key, value in values.items():
+                if key in default_state:
+                    if key == "train_steps":
+                        try:
+                            validated_values[key] = int(value)
+                        except (ValueError, TypeError):
+                            validated_values[key] = default_state[key]
+                    elif key == "batch_size":
+                        try:
+                            validated_values[key] = int(value)
+                        except (ValueError, TypeError):
+                            validated_values[key] = default_state[key]
+                    elif key == "learning_rate":
+                        try:
+                            validated_values[key] = float(value)
+                        except (ValueError, TypeError):
+                            validated_values[key] = default_state[key]
+                    elif key == "save_iterations":
+                        try:
+                            validated_values[key] = int(value)
+                        except (ValueError, TypeError):
+                            validated_values[key] = default_state[key]
+                    elif key == "lora_rank" and value not in ["16", "32", "64", "128", "256", "512", "1024"]:
                         validated_values[key] = default_state[key]
+                    elif key == "lora_alpha" and value not in ["16", "32", "64", "128", "256", "512", "1024"]:
                         validated_values[key] = default_state[key]
+                    else:
+                        validated_values[key] = value

+            try:
+                # First verify we can serialize to JSON
+                json_data = json.dumps(validated_values, indent=2)
+
+                # Write to the file
+                with open(ui_state_file, 'w') as f:
+                    f.write(json_data)
+                logger.debug(f"UI state saved successfully")
+            except Exception as e:
+                logger.error(f"Error saving UI state: {str(e)}")

     def _backup_and_recreate_ui_state(self, ui_state_file, default_state):
         """Backup the corrupted UI state file and create a new one with defaults"""

@@ -229,130 +234,133 @@ class TrainingService:
             "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS
         }

+        # Use lock for reading too to avoid reading during a write
+        with self.file_lock:
+
+            if not ui_state_file.exists():
+                logger.info("UI state file does not exist, using default values")
                 return default_state

+            try:
+                # First check if the file is empty
+                file_size = ui_state_file.stat().st_size
+                if file_size == 0:
+                    logger.warning("UI state file exists but is empty, using default values")
                     return default_state
+
+                with open(ui_state_file, 'r') as f:
+                    file_content = f.read().strip()
+                    if not file_content:
+                        logger.warning("UI state file is empty or contains only whitespace, using default values")
+                        return default_state

                 try:
+                    saved_state = json.loads(file_content)
+                except json.JSONDecodeError as e:
+                    logger.error(f"Error parsing UI state JSON: {str(e)}")
+                    # Instead of showing the error, recreate the file with defaults
+                    self._backup_and_recreate_ui_state(ui_state_file, default_state)
+                    return default_state

+                # Clean up model type if it contains " (LoRA)" suffix
+                if "model_type" in saved_state and " (LoRA)" in saved_state["model_type"]:
+                    saved_state["model_type"] = saved_state["model_type"].replace(" (LoRA)", "")
+                    logger.info(f"Removed (LoRA) suffix from saved model type: {saved_state['model_type']}")
+
+                # Convert numeric values to appropriate types
+                if "train_steps" in saved_state:
+                    try:
+                        saved_state["train_steps"] = int(saved_state["train_steps"])
+                    except (ValueError, TypeError):
+                        saved_state["train_steps"] = default_state["train_steps"]
+                        logger.warning("Invalid train_steps value, using default")

+                if "batch_size" in saved_state:
+                    try:
+                        saved_state["batch_size"] = int(saved_state["batch_size"])
+                    except (ValueError, TypeError):
+                        saved_state["batch_size"] = default_state["batch_size"]
+                        logger.warning("Invalid batch_size value, using default")
+
+                if "learning_rate" in saved_state:
+                    try:
+                        saved_state["learning_rate"] = float(saved_state["learning_rate"])
+                    except (ValueError, TypeError):
+                        saved_state["learning_rate"] = default_state["learning_rate"]
+                        logger.warning("Invalid learning_rate value, using default")
+
+                if "save_iterations" in saved_state:
+                    try:
+                        saved_state["save_iterations"] = int(saved_state["save_iterations"])
+                    except (ValueError, TypeError):
+                        saved_state["save_iterations"] = default_state["save_iterations"]
+                        logger.warning("Invalid save_iterations value, using default")
+
+                # Make sure we have all keys (in case structure changed)
+                merged_state = default_state.copy()
+                merged_state.update({k: v for k, v in saved_state.items() if v is not None})

+                # Validate model_type is in available choices
+                if merged_state["model_type"] not in MODEL_TYPES:
+                    # Try to map from internal name
+                    model_found = False
+                    for display_name, internal_name in MODEL_TYPES.items():
+                        if internal_name == merged_state["model_type"]:
+                            merged_state["model_type"] = display_name
+                            model_found = True
+                            break
+                    # If still not found, use default
+                    if not model_found:
+                        merged_state["model_type"] = default_state["model_type"]
+                        logger.warning(f"Invalid model type in saved state, using default")
+
+                # Validate model_version is appropriate for model_type
+                if "model_type" in merged_state and "model_version" in merged_state:
+                    model_internal_type = MODEL_TYPES.get(merged_state["model_type"])
+                    if model_internal_type:
+                        valid_versions = MODEL_VERSIONS.get(model_internal_type, {}).keys()
+                        if merged_state["model_version"] not in valid_versions:
+                            # Set to default for this model type
+                            from vms.ui.project.tabs.train_tab import TrainTab
+                            train_tab = TrainTab(None)  # Temporary instance just for the helper method
+                            merged_state["model_version"] = train_tab.get_default_model_version(saved_state["model_type"])
+                            logger.warning(f"Invalid model version for {merged_state['model_type']}, using default")

+                # Validate training_type is in available choices
+                if merged_state["training_type"] not in TRAINING_TYPES:
+                    # Try to map from internal name
+                    training_found = False
+                    for display_name, internal_name in TRAINING_TYPES.items():
+                        if internal_name == merged_state["training_type"]:
+                            merged_state["training_type"] = display_name
+                            training_found = True
+                            break
+                    # If still not found, use default
+                    if not training_found:
+                        merged_state["training_type"] = default_state["training_type"]
+                        logger.warning(f"Invalid training type in saved state, using default")

+                # Validate training_preset is in available choices
+                if merged_state["training_preset"] not in TRAINING_PRESETS:
+                    merged_state["training_preset"] = default_state["training_preset"]
+                    logger.warning(f"Invalid training preset in saved state, using default")
+
+                # Validate lora_rank is in allowed values
+                if merged_state.get("lora_rank") not in ["16", "32", "64", "128", "256", "512", "1024"]:
+                    merged_state["lora_rank"] = default_state["lora_rank"]
+                    logger.warning(f"Invalid lora_rank in saved state, using default")
+
+                # Validate lora_alpha is in allowed values
+                if merged_state.get("lora_alpha") not in ["16", "32", "64", "128", "256", "512", "1024"]:
+                    merged_state["lora_alpha"] = default_state["lora_alpha"]
+                    logger.warning(f"Invalid lora_alpha in saved state, using default")
+
+                return merged_state
+            except Exception as e:
+                logger.error(f"Error loading UI state: {str(e)}")
+                # If anything goes wrong, backup and recreate
+                self._backup_and_recreate_ui_state(ui_state_file, default_state)
+                return default_state

     def ensure_valid_ui_state_file(self):
         """Ensure UI state file exists and is valid JSON"""
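
Taken together, the training.py changes implement one pattern: every read or write of ui_state.json happens under a threading.Lock, numeric fields are coerced to their expected types, and anything invalid falls back to a default (or triggers a backup-and-recreate of the file). A minimal standalone sketch of that pattern, with invented field names and defaults rather than the real VMS configuration:

# Illustrative sketch only; DEFAULT_STATE and UIStateStore are made up for this example.
import json
import threading
from pathlib import Path

DEFAULT_STATE = {"train_steps": 1000, "batch_size": 1, "learning_rate": 3e-5}

class UIStateStore:
    def __init__(self, path: Path):
        self.path = path
        self.file_lock = threading.Lock()  # prevents concurrent read/write

    def save(self, values: dict) -> None:
        validated = DEFAULT_STATE.copy()
        for key, value in values.items():
            if key not in DEFAULT_STATE:
                continue
            try:
                # Coerce to the type of the default; fall back on failure.
                validated[key] = type(DEFAULT_STATE[key])(value)
            except (ValueError, TypeError):
                validated[key] = DEFAULT_STATE[key]
        with self.file_lock:
            # Serialize first so a failure cannot leave a half-written file behind.
            data = json.dumps(validated, indent=2)
            self.path.write_text(data)

    def load(self) -> dict:
        with self.file_lock:
            if not self.path.exists() or self.path.stat().st_size == 0:
                return DEFAULT_STATE.copy()
            try:
                saved = json.loads(self.path.read_text())
            except json.JSONDecodeError:
                # Corrupted file: back it up and start over with defaults.
                self.path.replace(self.path.with_suffix(".json.bak"))
                return DEFAULT_STATE.copy()
        merged = DEFAULT_STATE.copy()
        merged.update({k: v for k, v in saved.items() if v is not None})
        return merged

store = UIStateStore(Path("ui_state.json"))
store.save({"train_steps": "2000", "learning_rate": "bad"})
print(store.load())  # {'train_steps': 2000, 'batch_size': 1, 'learning_rate': 3e-05}
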
vms/ui/project/tabs/preview_tab.py CHANGED

@@ -298,7 +298,7 @@ class PreviewTab(BaseTab):
         # Update model_version choices when model_type changes or tab is selected
         if hasattr(self.app, 'tabs_component') and self.app.tabs_component is not None:
             self.app.tabs_component.select(
-                fn=self.
+                fn=self.sync_model_type_and_versions,
                 inputs=[],
                 outputs=[
                     self.components["model_type"],

@@ -391,7 +391,7 @@ class PreviewTab(BaseTab):
             self.components["conditioning_image"]: gr.Image(visible=show_conditioning_image)
         }

-    def
+    def sync_model_type_and_versions(self) -> Tuple[str, str]:
         """Sync model type with training tab when preview tab is selected and update model version choices"""
         model_type = self.get_default_model_type()
         model_version = ""

@@ -401,19 +401,15 @@ class PreviewTab(BaseTab):
             preview_state = ui_state.get("preview", {})
             model_version = preview_state.get("model_version", "")

+        # If no model version specified or invalid, use default
         if not model_version:
-            #
+            # Get the internal model type
             internal_type = MODEL_TYPES.get(model_type)
             if internal_type and internal_type in MODEL_VERSIONS:
-                if
-                model_version = f"{first_version} - {model_version_info.get('name', '')}"
+                versions = list(MODEL_VERSIONS[internal_type].keys())
+                if versions:
+                    model_version = versions[0]

-        # If we couldn't get it, use default
-        if not model_version:
-            model_version = self.get_default_model_version(model_type)
-
         return model_type, model_version

     def update_resolution(self, preset: str) -> Tuple[int, int, float]:
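
The preview-tab change replaces a broken handler reference with sync_model_type_and_versions, which re-reads the shared state when the tab is selected and falls back to the first registered version. A rough sketch of that tab-select wiring, assuming Gradio 4.x and using an invented MODEL_VERSIONS table and shared_state dict in place of the real app state:

# Illustrative sketch; not the actual VMS wiring.
import gradio as gr

MODEL_VERSIONS = {"wan": {"Wan-AI/Wan2.1-T2V-1.3B": {}, "Wan-AI/Wan2.1-T2V-14B": {}}}
shared_state = {"model_type": "wan", "model_version": ""}

def sync_model_type_and_versions():
    model_type = shared_state["model_type"]
    model_version = shared_state["model_version"]
    versions = list(MODEL_VERSIONS.get(model_type, {}).keys())
    if not model_version and versions:
        model_version = versions[0]
    # Return the synced values; the Dropdown gets both new choices and a value.
    return model_type, gr.Dropdown(choices=versions, value=model_version)

with gr.Blocks() as demo:
    with gr.Tabs() as tabs:
        with gr.Tab("Train"):
            gr.Markdown("training controls go here")
        with gr.Tab("Preview"):
            model_type_box = gr.Textbox(label="Model Type")
            model_version_dd = gr.Dropdown(choices=[], label="Model Version")
    # Re-sync the preview widgets whenever a tab is selected.
    tabs.select(fn=sync_model_type_and_versions, inputs=[], outputs=[model_type_box, model_version_dd])

demo.launch()
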
vms/ui/project/tabs/train_tab.py CHANGED

@@ -69,7 +69,7 @@ class TrainTab(BaseTab):
         # Get model versions for the default model type
         default_model_versions = self.get_model_version_choices(default_model_type)
         default_model_version = self.get_default_model_version(default_model_type)
-
+        print(f"default_model_version(default_model_type) = {default_model_version}")
         self.components["model_version"] = gr.Dropdown(
             choices=default_model_versions,
             label="Model Version",

@@ -214,6 +214,37 @@

         return tab

+    def update_model_type_and_version(self, model_type: str, model_version: str):
+        """Update both model type and version together to keep them in sync"""
+        # Get internal model type
+        internal_type = MODEL_TYPES.get(model_type)
+
+        # Make sure model_version is valid for this model type
+        if internal_type and internal_type in MODEL_VERSIONS:
+            valid_versions = list(MODEL_VERSIONS[internal_type].keys())
+            if not model_version or model_version not in valid_versions:
+                if valid_versions:
+                    model_version = valid_versions[0]
+
+        # Update UI state with both values to keep them in sync
+        self.app.update_ui_state(model_type=model_type, model_version=model_version)
+        return None
+
+    def save_model_version(self, model_type: str, model_version: str):
+        """Save model version ensuring it's consistent with model type"""
+        internal_type = MODEL_TYPES.get(model_type)
+
+        # Verify the model_version is compatible with the current model_type
+        if internal_type and internal_type in MODEL_VERSIONS:
+            valid_versions = MODEL_VERSIONS[internal_type].keys()
+            if model_version not in valid_versions:
+                # Don't save incompatible version
+                return None
+
+        # Save the model version along with current model type to ensure consistency
+        self.app.update_ui_state(model_type=model_type, model_version=model_version)
+        return None
+
     def connect_events(self) -> None:
         """Connect event handlers to UI components"""
         # Model type change event - Update model version dropdown choices

@@ -222,8 +253,8 @@
             inputs=[self.components["model_type"]],
             outputs=[self.components["model_version"]]
         ).then(
-            fn=
-            inputs=[self.components["model_type"]],
+            fn=self.update_model_type_and_version,  # Add this new function
+            inputs=[self.components["model_type"], self.components["model_version"]],
             outputs=[]
         ).then(
             # Use get_model_info instead of update_model_info

@@ -234,8 +265,8 @@

         # Model version change event
         self.components["model_version"].change(
-            fn=
-            inputs=[self.components["model_version"]],
+            fn=self.save_model_version,  # Replace with this new function
+            inputs=[self.components["model_type"], self.components["model_version"]],
             outputs=[]
         )

@@ -399,10 +430,13 @@
         """Update model version choices based on selected model type"""
         model_versions = self.get_model_version_choices(model_type)
         default_version = self.get_default_model_version(model_type)
+        print(f"update_model_versions({model_type}): default_version = {default_version}")
+        # Update UI state with proper model_type first (add this line)
+        self.app.update_ui_state(model_type=model_type)

         # Update the model_version dropdown with new choices and default value
         return gr.Dropdown(choices=model_versions, value=default_version)
-
+
     def handle_training_start(
         self, preset, model_type, model_version, training_type,
         lora_rank, lora_alpha, train_steps, batch_size, learning_rate,

@@ -477,22 +511,24 @@
         if not internal_type or internal_type not in MODEL_VERSIONS:
             return []

-        #
+        # Return just the model IDs without formatting
+        return list(MODEL_VERSIONS.get(internal_type, {}).keys())
+

     def get_default_model_version(self, model_type: str) -> str:
         """Get default model version for the given model type"""
         # Convert UI display name to internal name
         internal_type = MODEL_TYPES.get(model_type)
+        print(f"get_default_model_version({model_type}) = {internal_type}")
         if not internal_type or internal_type not in MODEL_VERSIONS:
             return ""

         # Get the first version available for this model type
         versions = MODEL_VERSIONS.get(internal_type, {})
         if versions:
+            model_versions = list(versions.keys())
+            if model_versions:
+                return model_versions[0]
         return ""

     def update_model_info(self, model_type: str, training_type: str) -> Dict:
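
The two new train-tab handlers encode a single consistency rule: whenever the (model_type, model_version) pair is persisted, the version must belong to the selected type, otherwise that type's first version is used (or the save is skipped). A standalone sketch of that rule, with invented MODEL_TYPES / MODEL_VERSIONS stand-ins rather than the real VMS tables:

# Illustrative sketch; the tables below are placeholders.
MODEL_TYPES = {"Wan": "wan", "LTX-Video": "ltxv"}
MODEL_VERSIONS = {
    "wan": {"Wan-AI/Wan2.1-T2V-1.3B": {}, "Wan-AI/Wan2.1-T2V-14B": {}},
    "ltxv": {"Lightricks/LTX-Video": {}},
}

def reconcile(model_type_display: str, model_version: str) -> tuple[str, str]:
    # Map the display name to the internal type, then check the version against it.
    internal = MODEL_TYPES.get(model_type_display)
    valid = list(MODEL_VERSIONS.get(internal, {}).keys()) if internal else []
    if model_version not in valid:
        model_version = valid[0] if valid else ""
    return model_type_display, model_version

# An incompatible version is swapped for the first valid one for that type.
print(reconcile("Wan", "Lightricks/LTX-Video"))
# ('Wan', 'Wan-AI/Wan2.1-T2V-1.3B')
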