Julian Bilcke
committed
Commit · 7f039e5
Parent(s): 55a4bb8

the HF space got corrupted, here's an attempt at salvaging it

Browse files:
- vms/ui/project/services/training.py  +123 -23
- vms/ui/project/tabs/manage_tab.py  +27 -1
vms/ui/project/services/training.py
CHANGED
@@ -810,9 +810,18 @@ class TrainingService:

        # Update with resume_from_checkpoint if provided
        if resume_from_checkpoint:
+            # Validate checkpoints and find a valid one to resume from
+            valid_checkpoint = self.validate_and_find_valid_checkpoint()
+            if valid_checkpoint:
+                config.resume_from_checkpoint = "latest"
+                checkpoint_step = int(Path(valid_checkpoint).name.split("_")[-1])
+                self.append_log(f"Resuming from validated checkpoint at step {checkpoint_step}")
+                logger.info(f"Resuming from validated checkpoint: {valid_checkpoint}")
+            else:
+                error_msg = "No valid checkpoints found to resume from"
+                logger.error(error_msg)
+                self.append_log(error_msg)
+                return error_msg, "No valid checkpoints available"

        # Common settings for both models
        config.mixed_precision = DEFAULT_MIXED_PRECISION
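Both this hunk and the validation helper added below assume Finetrainers writes checkpoints into directories named finetrainers_step_<N>, so the step count can be recovered from the directory name alone. A minimal sketch of that parsing (the directory name here is hypothetical):

    from pathlib import Path

    checkpoint_dir = Path("finetrainers_step_500")    # hypothetical checkpoint directory
    step = int(checkpoint_dir.name.split("_")[-1])    # -> 500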
@@ -1088,6 +1097,77 @@ class TrainingService:
        except:
            return False

+    def validate_and_find_valid_checkpoint(self) -> Optional[str]:
+        """Validate checkpoint directories and find the most recent valid one
+
+        Returns:
+            Path to valid checkpoint directory or None if no valid checkpoint found
+        """
+        # Find all checkpoint directories
+        checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
+        if not checkpoints:
+            logger.info("No checkpoint directories found")
+            return None
+
+        # Sort by step number in descending order (latest first)
+        sorted_checkpoints = sorted(checkpoints, key=lambda x: int(x.name.split("_")[-1]), reverse=True)
+
+        corrupted_checkpoints = []
+
+        for checkpoint_dir in sorted_checkpoints:
+            step_num = int(checkpoint_dir.name.split("_")[-1])
+            logger.info(f"Validating checkpoint at step {step_num}: {checkpoint_dir}")
+
+            # Check if the .metadata file exists
+            metadata_file = checkpoint_dir / ".metadata"
+            if not metadata_file.exists():
+                logger.warning(f"Checkpoint {checkpoint_dir.name} is corrupted: missing .metadata file")
+                corrupted_checkpoints.append(checkpoint_dir)
+                continue
+
+            # Try to read the metadata file to ensure it's not corrupted
+            try:
+                with open(metadata_file, 'r') as f:
+                    metadata = json.load(f)
+                # Basic validation - metadata should have expected structure
+                if not isinstance(metadata, dict):
+                    raise ValueError("Invalid metadata format")
+                logger.info(f"Checkpoint {checkpoint_dir.name} is valid")
+
+                # Clean up any corrupted checkpoints we found before this valid one
+                if corrupted_checkpoints:
+                    self.cleanup_corrupted_checkpoints(corrupted_checkpoints)
+
+                return str(checkpoint_dir)
+
+            except (json.JSONDecodeError, IOError, ValueError) as e:
+                logger.warning(f"Checkpoint {checkpoint_dir.name} is corrupted: failed to read .metadata: {e}")
+                corrupted_checkpoints.append(checkpoint_dir)
+                continue
+
+        # If we reach here, all checkpoints are corrupted
+        if corrupted_checkpoints:
+            logger.error("All checkpoint directories are corrupted")
+            self.cleanup_corrupted_checkpoints(corrupted_checkpoints)
+
+        return None
+
+    def cleanup_corrupted_checkpoints(self, corrupted_checkpoints: List[Path]) -> None:
+        """Remove corrupted checkpoint directories
+
+        Args:
+            corrupted_checkpoints: List of corrupted checkpoint directory paths
+        """
+        for checkpoint_dir in corrupted_checkpoints:
+            try:
+                step_num = int(checkpoint_dir.name.split("_")[-1])
+                logger.info(f"Removing corrupted checkpoint at step {step_num}: {checkpoint_dir}")
+                shutil.rmtree(checkpoint_dir)
+                self.append_log(f"Removed corrupted checkpoint: {checkpoint_dir.name}")
+            except Exception as e:
+                logger.error(f"Failed to remove corrupted checkpoint {checkpoint_dir}: {e}")
+                self.append_log(f"Failed to remove corrupted checkpoint {checkpoint_dir.name}: {e}")
+
    def recover_interrupted_training(self) -> Dict[str, Any]:
        """Attempt to recover interrupted training

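The validity rule above boils down to: a checkpoint directory is usable only if its .metadata file exists and parses as a JSON object; anything else is treated as corrupted and removed. A standalone sketch of the same rule outside the service class (the helper name and the "output" directory are illustrative, not part of the commit):

    import json
    from pathlib import Path

    def is_valid_checkpoint(checkpoint_dir: Path) -> bool:
        # Usable only if .metadata exists and contains a JSON object
        metadata_file = checkpoint_dir / ".metadata"
        if not metadata_file.exists():
            return False
        try:
            return isinstance(json.loads(metadata_file.read_text()), dict)
        except (json.JSONDecodeError, OSError):
            return False

    # Pick the newest usable checkpoint among finetrainers_step_* directories
    candidates = sorted(Path("output").glob("finetrainers_step_*"),
                        key=lambda p: int(p.name.split("_")[-1]), reverse=True)
    latest_valid = next((p for p in candidates if is_valid_checkpoint(p)), None)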
@@ -1097,9 +1177,9 @@ class TrainingService:
        status = self.get_status()
        ui_updates = {}

-        # Check for any checkpoints, even if status doesn't indicate training
+        # Check for any valid checkpoints, even if status doesn't indicate training
+        valid_checkpoint = self.validate_and_find_valid_checkpoint()
+        has_checkpoints = valid_checkpoint is not None

        # If status indicates training but process isn't running, or if we have checkpoints
        # and no active training process, try to recover
@@ -1145,15 +1225,13 @@ class TrainingService:
            }
            return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}

+        # Use the valid checkpoint we found
        latest_checkpoint = None
        checkpoint_step = 0

-        if has_checkpoints:
-            checkpoint_step = int(latest_checkpoint.name.split("_")[-1])
-            logger.info(f"Found checkpoint at step {checkpoint_step}")
+        if has_checkpoints and valid_checkpoint:
+            checkpoint_step = int(Path(valid_checkpoint).name.split("_")[-1])
+            logger.info(f"Found valid checkpoint at step {checkpoint_step}")

        # both options are valid, but imho it is easier to just return "latest"
        # under the hood Finetrainers will convert ("latest") to (-1)
@@ -1480,17 +1558,20 @@ class TrainingService:
            self.append_log(f"Error uploading to hub: {str(e)}")
            return False

+    def get_model_output_info(self) -> Dict[str, Any]:
+        """Return info about the model safetensors including path and step count

        Returns:
+            Dict with 'path' (str or None) and 'steps' (int or None)
        """
+        result = {"path": None, "steps": None}

        # Check if the root level file exists (this should be the primary location)
        model_output_safetensors_path = self.app.output_path / "pytorch_lora_weights.safetensors"
        if model_output_safetensors_path.exists():
+            result["path"] = str(model_output_safetensors_path)
+            # For root level, we can't determine steps easily, so return None
+            return result

        # Check in lora_weights directory
        lora_weights_dir = self.app.output_path / "lora_weights"
@@ -1503,6 +1584,9 @@
            latest_lora_checkpoint = max(lora_checkpoints, key=lambda x: int(x.name))
            logger.info(f"Found latest LoRA checkpoint: {latest_lora_checkpoint}")

+            # Extract step count from directory name
+            result["steps"] = int(latest_lora_checkpoint.name)
+
            # List contents of the latest checkpoint directory
            checkpoint_contents = list(latest_lora_checkpoint.glob("*"))
            logger.info(f"Contents of LoRA checkpoint {latest_lora_checkpoint.name}: {checkpoint_contents}")
@@ -1511,7 +1595,8 @@
            lora_safetensors = latest_lora_checkpoint / "pytorch_lora_weights.safetensors"
            if lora_safetensors.exists():
                logger.info(f"Found weights in latest LoRA checkpoint: {lora_safetensors}")
+                result["path"] = str(lora_safetensors)
+                return result

            # Also check for other common weight file names
            possible_weight_files = [
@@ -1525,24 +1610,27 @@
                weight_path = latest_lora_checkpoint / weight_file
                if weight_path.exists():
                    logger.info(f"Found weights file {weight_file} in latest LoRA checkpoint: {weight_path}")
+                    result["path"] = str(weight_path)
+                    return result

            # Check if any .safetensors files exist
            safetensors_files = list(latest_lora_checkpoint.glob("*.safetensors"))
            if safetensors_files:
                logger.info(f"Found .safetensors files in LoRA checkpoint: {safetensors_files}")
                # Return the first .safetensors file found
+                result["path"] = str(safetensors_files[0])
+                return result

        # Fallback: check for direct safetensors file in lora_weights root
        lora_safetensors = lora_weights_dir / "pytorch_lora_weights.safetensors"
        if lora_safetensors.exists():
            logger.info(f"Found weights in lora_weights directory: {lora_safetensors}")
+            result["path"] = str(lora_safetensors)
+            return result
        else:
            logger.info(f"pytorch_lora_weights.safetensors not found in lora_weights directory")

-        # If not found in root or lora_weights, log the issue
+        # If not found in root or lora_weights, log the issue and check fallback
        logger.warning(f"Model weights not found at expected location: {model_output_safetensors_path}")
        logger.info(f"Checking output directory contents: {list(self.app.output_path.glob('*'))}")

@@ -1553,6 +1641,9 @@
            latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
            logger.info(f"Latest checkpoint directory: {latest_checkpoint}")

+            # Extract step count from checkpoint directory name
+            result["steps"] = int(latest_checkpoint.name.split("_")[-1])
+
            # Log contents of latest checkpoint
            checkpoint_contents = list(latest_checkpoint.glob("*"))
            logger.info(f"Contents of latest checkpoint {latest_checkpoint.name}: {checkpoint_contents}")
@@ -1560,11 +1651,20 @@
            checkpoint_weights = latest_checkpoint / "pytorch_lora_weights.safetensors"
            if checkpoint_weights.exists():
                logger.info(f"Found weights in latest checkpoint: {checkpoint_weights}")
+                result["path"] = str(checkpoint_weights)
+                return result
            else:
                logger.info(f"pytorch_lora_weights.safetensors not found in checkpoint directory")

+        return result
+
+    def get_model_output_safetensors(self) -> Optional[str]:
+        """Return the path to the model safetensors
+
+        Returns:
+            Path to safetensors file or None if not found
+        """
+        return self.get_model_output_info()["path"]

    def create_training_dataset_zip(self) -> str:
        """Create a ZIP file containing all training data
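A rough usage sketch for the new accessor pair (the service instance and paths here are hypothetical): get_model_output_info() returns a small dict, and the old-style path accessor is now a thin wrapper over it.

    info = training_service.get_model_output_info()
    # e.g. {"path": "<output>/lora_weights/500/pytorch_lora_weights.safetensors", "steps": 500},
    # or {"path": None, "steps": None} when no weights exist yet

    path = training_service.get_model_output_safetensors()
    assert path == info["path"]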
vms/ui/project/tabs/manage_tab.py
CHANGED
@@ -25,6 +25,32 @@ class ManageTab(BaseTab):
        self.id = "manage_tab"
        self.title = "5️⃣ Storage"

+    def get_download_button_text(self) -> str:
+        """Get the dynamic text for the download button based on current model state"""
+        try:
+            model_info = self.app.training.get_model_output_info()
+            if model_info["path"] and model_info["steps"]:
+                return f"🧠 Download weights ({model_info['steps']} steps)"
+            elif model_info["path"]:
+                return "🧠 Download weights (.safetensors)"
+            else:
+                return "🧠 Download weights (not available)"
+        except Exception as e:
+            logger.warning(f"Error getting model info for button text: {e}")
+            return "🧠 Download weights (.safetensors)"
+
+    def update_download_button_text(self) -> gr.update:
+        """Update the download button text"""
+        return gr.update(value=self.get_download_button_text())
+
+    def download_and_update_button(self):
+        """Handle download and return updated button with current text"""
+        # Get the safetensors path for download
+        path = self.app.training.get_model_output_safetensors()
+        # For DownloadButton, we need to return the file path directly for download
+        # The button text will be updated on next render
+        return path
+
    def create(self, parent=None) -> gr.TabItem:
        """Create the Manage tab UI components"""
        with gr.TabItem(self.title, id=self.id) as tab:
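The label logic above has three outcomes; an illustrative standalone version of the same mapping (not part of the commit, mirrors get_download_button_text for a given model_info dict):

    def label_for(model_info: dict) -> str:
        # Same branching as ManageTab.get_download_button_text()
        if model_info["path"] and model_info["steps"]:
            return f"🧠 Download weights ({model_info['steps']} steps)"
        elif model_info["path"]:
            return "🧠 Download weights (.safetensors)"
        return "🧠 Download weights (not available)"

    label_for({"path": "out/pytorch_lora_weights.safetensors", "steps": 500})   # "🧠 Download weights (500 steps)"
    label_for({"path": "out/pytorch_lora_weights.safetensors", "steps": None})  # "🧠 Download weights (.safetensors)"
    label_for({"path": None, "steps": None})                                    # "🧠 Download weights (not available)"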
@@ -45,7 +71,7 @@
                gr.Markdown("📦 Training dataset download disabled for large datasets")

                self.components["download_model_btn"] = gr.DownloadButton(
+                    self.get_download_button_text(),
                    variant="secondary",
                    size="lg"
                )
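The diff adds the handlers but does not show the event wiring; a minimal sketch of how they might be connected elsewhere in ManageTab, assuming a hypothetical connect_events() hook and that the DownloadButton's click output can target the button itself (only the handler and component names come from this commit, the wiring is illustrative):

    def connect_events(self) -> None:
        # Clicking the button serves the current weights and refreshes the button's file value;
        # the label text is refreshed the next time the button is (re)created.
        self.components["download_model_btn"].click(
            fn=self.download_and_update_button,
            outputs=self.components["download_model_btn"],
        )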