| """ | |
| SmolLM3 Model Wrapper | |
| Handles model loading, tokenizer, and training setup | |
| """ | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| AutoConfig, | |
| TrainingArguments, | |
| Trainer | |
| ) | |
| from typing import Optional, Dict, Any | |
| import logging | |
| logger = logging.getLogger(__name__) | |


class SmolLM3Model:
    """Wrapper for SmolLM3 model and tokenizer"""

    def __init__(
        self,
        model_name: str = "HuggingFaceTB/SmolLM3-3B",
        max_seq_length: int = 4096,
        config: Optional[Any] = None,
        device_map: Optional[str] = None,
        torch_dtype: Optional[torch.dtype] = None
    ):
        self.model_name = model_name
        self.max_seq_length = max_seq_length
        self.config = config

        # Set device and dtype
        if torch_dtype is None:
            if torch.cuda.is_available():
                # Prefer bfloat16 on GPUs with compute capability >= 8 (Ampere and newer)
                self.torch_dtype = (
                    torch.bfloat16
                    if torch.cuda.get_device_capability()[0] >= 8
                    else torch.float16
                )
            else:
                self.torch_dtype = torch.float32
        else:
            self.torch_dtype = torch_dtype

        if device_map is None:
            self.device_map = "auto" if torch.cuda.is_available() else "cpu"
        else:
            self.device_map = device_map

        # Load tokenizer and model
        self._load_tokenizer()
        self._load_model()

    def _load_tokenizer(self):
        """Load the tokenizer"""
        logger.info(f"Loading tokenizer from {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                use_fast=True
            )

            # Set pad token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info(f"Tokenizer loaded successfully. Vocab size: {self.tokenizer.vocab_size}")
        except Exception as e:
            logger.error(f"Failed to load tokenizer: {e}")
            raise

    def _load_model(self):
        """Load the model"""
        logger.info(f"Loading model from {self.model_name}")
        try:
            # Load model configuration
            model_config = AutoConfig.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )

            # Update configuration if needed
            if hasattr(model_config, 'max_position_embeddings'):
                model_config.max_position_embeddings = self.max_seq_length

            # Load model
            model_kwargs = {
                "torch_dtype": self.torch_dtype,
                "device_map": self.device_map,
                "trust_remote_code": True
            }

            # Only request flash attention when the training config asks for it
            # and the already-loaded model config exposes the flag
            if getattr(self.config, 'use_flash_attention', False) and hasattr(model_config, 'use_flash_attention_2'):
                model_kwargs["use_flash_attention_2"] = True

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                config=model_config,
                **model_kwargs
            )

            # Enable gradient checkpointing if specified
            if self.config and getattr(self.config, 'use_gradient_checkpointing', False):
                self.model.gradient_checkpointing_enable()

            logger.info(f"Model loaded successfully. Parameters: {self.model.num_parameters():,}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def get_training_arguments(self, output_dir: str, **kwargs) -> TrainingArguments:
        """Get training arguments for the Trainer"""
        if self.config is None:
            raise ValueError("Config is required to get training arguments")

        # Log the config type and attributes to help diagnose misconfigured configs
        logger.info(f"Config type: {type(self.config)}")
        logger.info(f"Config attributes: {[attr for attr in dir(self.config) if not attr.startswith('_')]}")

        # Build the argument dict from the config, then merge in any keyword overrides
        training_args = {}

        # Add arguments one by one with error checking
        try:
            training_args["output_dir"] = output_dir
            training_args["per_device_train_batch_size"] = self.config.batch_size
            training_args["per_device_eval_batch_size"] = self.config.batch_size
            training_args["gradient_accumulation_steps"] = self.config.gradient_accumulation_steps
            training_args["learning_rate"] = self.config.learning_rate
            training_args["weight_decay"] = self.config.weight_decay
            training_args["warmup_steps"] = self.config.warmup_steps
            training_args["max_steps"] = self.config.max_iters
            training_args["save_steps"] = self.config.save_steps
            training_args["eval_steps"] = self.config.eval_steps
            training_args["logging_steps"] = self.config.logging_steps
            training_args["save_total_limit"] = self.config.save_total_limit
            training_args["eval_strategy"] = self.config.eval_strategy
            training_args["metric_for_best_model"] = self.config.metric_for_best_model
            training_args["greater_is_better"] = self.config.greater_is_better
            training_args["load_best_model_at_end"] = self.config.load_best_model_at_end
            training_args["fp16"] = self.config.fp16
            training_args["bf16"] = self.config.bf16
            training_args["ddp_backend"] = self.config.ddp_backend if torch.cuda.device_count() > 1 else None
            training_args["ddp_find_unused_parameters"] = self.config.ddp_find_unused_parameters if torch.cuda.device_count() > 1 else False
            training_args["report_to"] = None
            training_args["remove_unused_columns"] = False
            training_args["dataloader_pin_memory"] = False
            training_args["group_by_length"] = True
            training_args["length_column_name"] = "length"
            training_args["ignore_data_skip"] = False
            training_args["seed"] = 42
            training_args["data_seed"] = 42
            training_args["dataloader_num_workers"] = getattr(self.config, 'dataloader_num_workers', 4)
            training_args["max_grad_norm"] = getattr(self.config, 'max_grad_norm', 1.0)
            training_args["optim"] = self.config.optimizer
            training_args["lr_scheduler_type"] = self.config.scheduler
            training_args["warmup_ratio"] = 0.1
            training_args["save_strategy"] = "steps"
            training_args["logging_strategy"] = "steps"
            training_args["prediction_loss_only"] = True
        except Exception as e:
            logger.error(f"Error creating training arguments: {e}")
            raise

        # Override with kwargs
        training_args.update(kwargs)

        return TrainingArguments(**training_args)

    def save_pretrained(self, path: str):
        """Save model and tokenizer"""
        logger.info(f"Saving model and tokenizer to {path}")
        os.makedirs(path, exist_ok=True)

        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

        # Save configuration
        if self.config:
            import json
            config_dict = {k: v for k, v in self.config.__dict__.items()
                           if not k.startswith('_')}
            with open(os.path.join(path, 'training_config.json'), 'w') as f:
                json.dump(config_dict, f, indent=2, default=str)

    def load_checkpoint(self, checkpoint_path: str):
        """Load model from checkpoint"""
        logger.info(f"Loading checkpoint from {checkpoint_path}")
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                checkpoint_path,
                torch_dtype=self.torch_dtype,
                device_map=self.device_map,
                trust_remote_code=True
            )
            logger.info("Checkpoint loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise
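

# Minimal usage sketch, not part of the original module: the stand-in config
# values and output path below are illustrative assumptions, not the project's
# defaults. Running it downloads HuggingFaceTB/SmolLM3-3B, builds
# TrainingArguments from the config, and prints them.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_config = SimpleNamespace(
        batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=5e-5,
        weight_decay=0.01,
        warmup_steps=100,
        max_iters=1000,
        save_steps=500,
        eval_steps=100,
        logging_steps=10,
        save_total_limit=2,
        eval_strategy="steps",
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        load_best_model_at_end=True,
        fp16=False,
        bf16=False,
        ddp_backend="nccl",
        ddp_find_unused_parameters=False,
        optimizer="adamw_torch",
        scheduler="cosine",
        use_gradient_checkpointing=True,
        use_flash_attention=False,
    )

    wrapper = SmolLM3Model(config=demo_config)
    args = wrapper.get_training_arguments(output_dir="./smollm3-finetune")
    print(args)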