solves dataset dict issue
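The fix leans on the fact that `datasets.DatasetDict` subclasses the built-in `dict`, so a plain `isinstance(self.dataset, dict)` check is enough to tell a multi-split load apart from a single split. A minimal sketch of the two shapes (the dataset name here is an arbitrary stand-in, not the one this Space actually loads):

from datasets import load_dataset

# Without split=..., load_dataset returns a DatasetDict keyed by split name.
ds = load_dataset("imdb")        # arbitrary stand-in dataset
print(type(ds).__name__)         # DatasetDict
print(isinstance(ds, dict))      # True: DatasetDict subclasses dict
print(ds.column_names)           # {'train': ['text', 'label'], ...} - a dict per split, not a flat list

# A single split is a plain Dataset with a flat column_names list.
train = load_dataset("imdb", split="train")
print(type(train).__name__)      # Dataset
print(train.column_names)        # ['text', 'label']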
data.py CHANGED
@@ -190,26 +190,56 @@ class SmolLM3Dataset:
                 "length": input_length,
             }

-        # Process the dataset
-        …
-        if …
-        logger.info(f"…
-        return …
+        # Process the dataset - handle both single dataset and dictionary of splits
+        if isinstance(self.dataset, dict):
+            # Process each split individually
+            processed_dataset = {}
+            for split_name, split_dataset in self.dataset.items():
+                logger.info(f"Processing {split_name} split...")
+
+                # Format the split
+                processed_split = split_dataset.map(
+                    format_chat_template,
+                    remove_columns=split_dataset.column_names,
+                    desc=f"Formatting {split_name} dataset"
+                )
+
+                # Tokenize the split
+                tokenized_split = processed_split.map(
+                    tokenize_function,
+                    remove_columns=processed_split.column_names,
+                    desc=f"Tokenizing {split_name} dataset",
+                    batched=True,
+                )
+
+                processed_dataset[split_name] = tokenized_split
+        else:
+            # Single dataset
+            processed_dataset = self.dataset.map(
+                format_chat_template,
+                remove_columns=self.dataset.column_names,
+                desc="Formatting dataset"
+            )
+
+            # Tokenize the dataset
+            processed_dataset = processed_dataset.map(
+                tokenize_function,
+                remove_columns=processed_dataset.column_names,
+                desc="Tokenizing dataset",
+                batched=True,
+            )

+        # Log processing results
+        if isinstance(processed_dataset, dict):
+            logger.info(f"Dataset processed. Train samples: {len(processed_dataset['train'])}")
+            if "validation" in processed_dataset:
+                logger.info(f"Validation samples: {len(processed_dataset['validation'])}")
+            if "test" in processed_dataset:
+                logger.info(f"Test samples: {len(processed_dataset['test'])}")
+        else:
+            logger.info(f"Dataset processed. Samples: {len(processed_dataset)}")

+        return processed_dataset

     def get_train_dataset(self) -> Dataset:
         """Get training dataset"""
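For reference, the per-split branch can be exercised in isolation. The sketch below uses toy data and a stub standing in for this repo's `format_chat_template` (tokenization omitted); it shows the same pattern the patch introduces: iterate over the split dict and `map` each `Dataset` on its own.

from datasets import Dataset, DatasetDict

def format_chat_template(example):
    # Stub standing in for the real formatter in data.py.
    return {"text": f"user: {example['q']}\nassistant: {example['a']}"}

raw = DatasetDict({
    "train": Dataset.from_dict({"q": ["hi"], "a": ["hello"]}),
    "validation": Dataset.from_dict({"q": ["bye"], "a": ["goodbye"]}),
})

# Same shape as the patched branch: process each split individually,
# keeping the results keyed by split name.
processed = {}
for split_name, split_dataset in raw.items():
    processed[split_name] = split_dataset.map(
        format_chat_template,
        remove_columns=split_dataset.column_names,
        desc=f"Formatting {split_name} dataset",
    )

print(processed["train"][0])   # {'text': 'user: hi\nassistant: hello'}

Note that `DatasetDict.map` would also fan out over every split in a single call; the explicit loop trades that brevity for per-split `desc` labels and log lines, at the cost of returning a plain `dict` rather than a `DatasetDict`.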