Spaces:
Running
Running
| """ | |
| SmolLM3 Model Wrapper | |
| Handles model loading, tokenizer, and training setup | |
| """ | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| AutoConfig, | |
| TrainingArguments, | |
| Trainer | |
| ) | |
| from typing import Optional, Dict, Any | |
| import logging | |
logger = logging.getLogger(__name__)


class SmolLM3Model:
    """Wrapper for a SmolLM3 causal LM and its tokenizer.

    The tokenizer and model are loaded eagerly in ``__init__``. The optional
    ``config`` object supplies training hyper-parameters (read by
    ``get_training_arguments``) and the ``use_flash_attention`` /
    ``use_gradient_checkpointing`` switches (read whenever a model is loaded).
    """

    def __init__(
        self,
        model_name: str = "HuggingFaceTB/SmolLM3-3B",
        max_seq_length: int = 4096,
        config: Optional[Any] = None,
        device_map: Optional[str] = None,
        torch_dtype: Optional[torch.dtype] = None,
    ):
        """Resolve dtype/device placement, then load tokenizer and model.

        Args:
            model_name: Hub id or local path given to ``from_pretrained``.
            max_seq_length: Written into the model config's
                ``max_position_embeddings`` (when that field exists).
            config: Optional training config object; may be ``None``, in which
                case ``get_training_arguments`` cannot be used.
            device_map: ``device_map`` for ``from_pretrained``. Defaults to
                ``"auto"`` when CUDA is available, otherwise ``"cpu"``.
            torch_dtype: Dtype for the weights. Defaults to ``bfloat16`` on GPUs
                with compute capability >= 8, ``float16`` on older GPUs and
                ``float32`` on CPU.
        """
        self.model_name = model_name
        self.max_seq_length = max_seq_length
        self.config = config

        self.torch_dtype = self._default_dtype() if torch_dtype is None else torch_dtype
        if device_map is None:
            device_map = "auto" if torch.cuda.is_available() else "cpu"
        self.device_map = device_map

        # Load tokenizer and model
        self._load_tokenizer()
        self._load_model()

    @staticmethod
    def _default_dtype() -> torch.dtype:
        """Return bf16 on compute capability >= 8, fp16 on older GPUs, fp32 on CPU."""
        if not torch.cuda.is_available():
            return torch.float32
        major, _minor = torch.cuda.get_device_capability()
        return torch.bfloat16 if major >= 8 else torch.float16

    def _load_tokenizer(self):
        """Load the tokenizer, falling back to EOS as the pad token when none is set.

        Raises:
            Exception: Whatever ``AutoTokenizer.from_pretrained`` raises; it is
                logged and re-raised.
        """
        logger.info(f"Loading tokenizer from {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                use_fast=True,
            )
            # Padding is needed for batching; reuse EOS when the tokenizer has no pad token.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.info(f"Tokenizer loaded successfully. Vocab size: {self.tokenizer.vocab_size}")
        except Exception as e:
            logger.error(f"Failed to load tokenizer: {e}")
            raise

    def _resolve_attn_implementation(self) -> Optional[str]:
        """Return ``"flash_attention_2"`` if the config requests it and it can run here.

        Returns ``None`` (let transformers pick its default attention) when flash
        attention is not requested, or when a prerequisite is missing: CUDA, a
        GPU with compute capability >= 8, a half-precision dtype, or the
        ``flash_attn`` package. Each fallback is logged as a warning rather than
        raising, so a missing optional dependency does not abort loading.
        """
        import importlib.util

        if not getattr(self.config, 'use_flash_attention', False):
            return None
        if not torch.cuda.is_available():
            logger.warning("Flash attention requested but CUDA is not available; using default attention")
            return None
        if torch.cuda.get_device_capability()[0] < 8:
            logger.warning("Flash attention 2 needs compute capability >= 8; using default attention")
            return None
        if self.torch_dtype not in (torch.float16, torch.bfloat16):
            logger.warning(f"Flash attention needs fp16/bf16, got {self.torch_dtype}; using default attention")
            return None
        if importlib.util.find_spec("flash_attn") is None:
            logger.warning("Flash attention requested but 'flash_attn' is not installed; using default attention")
            return None
        return "flash_attention_2"

    def _model_load_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments shared by every ``AutoModelForCausalLM.from_pretrained`` call."""
        model_kwargs: Dict[str, Any] = {
            "torch_dtype": self.torch_dtype,
            "device_map": self.device_map,
            "trust_remote_code": True,
        }
        attn_implementation = self._resolve_attn_implementation()
        if attn_implementation is not None:
            model_kwargs["attn_implementation"] = attn_implementation
        return model_kwargs

    def _maybe_enable_gradient_checkpointing(self):
        """Enable gradient checkpointing on ``self.model`` if the config asks for it.

        A missing config, or a config without ``use_gradient_checkpointing``,
        means "off".
        """
        if getattr(self.config, 'use_gradient_checkpointing', False):
            self.model.gradient_checkpointing_enable()

    def _load_model(self):
        """Load the model with the resolved dtype, device map and attention backend.

        Raises:
            Exception: Whatever ``AutoConfig``/``AutoModelForCausalLM`` raise;
                it is logged and re-raised.
        """
        logger.info(f"Loading model from {self.model_name}")
        try:
            model_config = AutoConfig.from_pretrained(
                self.model_name,
                trust_remote_code=True,
            )
            # NOTE(review): this overwrites the pretrained value and can also
            # shrink it below what the checkpoint was trained with; confirm that
            # is intended for RoPE-scaling variants.
            if hasattr(model_config, 'max_position_embeddings'):
                model_config.max_position_embeddings = self.max_seq_length

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                config=model_config,
                **self._model_load_kwargs(),
            )
            self._maybe_enable_gradient_checkpointing()

            logger.info(f"Model loaded successfully. Parameters: {self.model.num_parameters():,}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def get_training_arguments(self, output_dir: str, **kwargs) -> "TrainingArguments":
        """Build ``TrainingArguments`` for the Trainer from ``self.config``.

        Args:
            output_dir: Directory the Trainer writes checkpoints and logs to.
            **kwargs: Any ``TrainingArguments`` field; overrides the value
                derived from the config.

        Returns:
            A ``TrainingArguments`` instance.

        Raises:
            ValueError: If the wrapper was created without a config.
        """
        if self.config is None:
            raise ValueError("Config is required to get training arguments")

        # DDP settings only apply when more than one GPU is visible.
        multi_gpu = torch.cuda.device_count() > 1

        training_args = {
            "output_dir": output_dir,
            "per_device_train_batch_size": self.config.batch_size,
            "per_device_eval_batch_size": self.config.batch_size,
            "gradient_accumulation_steps": self.config.gradient_accumulation_steps,
            "learning_rate": self.config.learning_rate,
            "weight_decay": self.config.weight_decay,
            # NOTE: in TrainingArguments a positive warmup_steps takes precedence
            # over warmup_ratio (set to 0.1 below).
            "warmup_steps": self.config.warmup_steps,
            "max_steps": self.config.max_iters,
            "save_steps": self.config.save_steps,
            "eval_steps": self.config.eval_steps,
            "logging_steps": self.config.logging_steps,
            "save_total_limit": self.config.save_total_limit,
            "eval_strategy": self.config.eval_strategy,
            "metric_for_best_model": self.config.metric_for_best_model,
            "greater_is_better": self.config.greater_is_better,
            "load_best_model_at_end": self.config.load_best_model_at_end,
            "fp16": self.config.fp16,
            "bf16": self.config.bf16,
            "ddp_backend": self.config.ddp_backend if multi_gpu else None,
            "ddp_find_unused_parameters": self.config.ddp_find_unused_parameters if multi_gpu else False,
            "report_to": "none",  # Disable external logging
            "remove_unused_columns": False,
            "dataloader_pin_memory": False,
            "group_by_length": True,
            "length_column_name": "length",
            "ignore_data_skip": False,
            "seed": 42,
            "data_seed": 42,
            "dataloader_num_workers": 4,
            "max_grad_norm": 1.0,
            "optim": self.config.optimizer,
            "lr_scheduler_type": self.config.scheduler,
            "warmup_ratio": 0.1,
            "save_strategy": "steps",
            "logging_strategy": "steps",
            "prediction_loss_only": True,
        }

        # Caller-supplied values win over config-derived ones.
        training_args.update(kwargs)

        return TrainingArguments(**training_args)

    def save_pretrained(self, path: str):
        """Save model, tokenizer and (if present) the training config to ``path``.

        The config's public attributes (names not starting with ``_``) are
        written to ``training_config.json``; values JSON cannot encode natively
        are stored via ``str()``.
        """
        logger.info(f"Saving model and tokenizer to {path}")
        os.makedirs(path, exist_ok=True)
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

        if self.config:
            import json

            config_dict = {
                k: v for k, v in self.config.__dict__.items() if not k.startswith('_')
            }
            with open(os.path.join(path, 'training_config.json'), 'w', encoding='utf-8') as f:
                json.dump(config_dict, f, indent=2, default=str)

    def load_checkpoint(self, checkpoint_path: str):
        """Replace ``self.model`` with weights loaded from ``checkpoint_path``.

        Uses the same dtype, device map and attention backend as the initial
        load, and re-applies gradient checkpointing when the config requests
        it. The tokenizer is not reloaded.

        Raises:
            Exception: Whatever ``from_pretrained`` raises; it is logged and re-raised.
        """
        logger.info(f"Loading checkpoint from {checkpoint_path}")
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                checkpoint_path,
                **self._model_load_kwargs(),
            )
            self._maybe_enable_gradient_checkpointing()
            logger.info("Checkpoint loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise