import os

import torch
from torch import nn
from tqdm import tqdm
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator
from transformers import get_scheduler
from accelerate import Accelerator
from accelerate.utils import set_seed
from collections import defaultdict
from torch.utils.data import DataLoader
import torch.optim as optim

from ..utils.data_utils import load_lm_dataset, extract_new_words_from_dataset, get_group_texts_func, get_tokenize_func
class EmbeddingCalibrator(nn.Module):
    def __init__(self, hidden_size, lora_r=None, lora_alpha=None, dtype=torch.bfloat16):
        super().__init__()
        self.use_lora = lora_r is not None
        if not self.use_lora:
            # Full-rank residual transform, zero-initialized so the calibrator starts as the identity
            self.weight = nn.Parameter(torch.zeros(hidden_size, hidden_size, dtype=dtype))
        else:
            self.lora_scaling = lora_alpha / lora_r if lora_alpha is not None else 1.0
            self.lora_A = nn.Parameter(torch.randn(lora_r, hidden_size, dtype=dtype) * (1 / lora_r))
            self.lora_B = nn.Parameter(torch.zeros(hidden_size, lora_r, dtype=dtype))

    def forward(self, x):
        if not self.use_lora:
            return x + torch.matmul(x, self.weight.t())
        else:
            # Low-rank adaptation
            lora_out = torch.matmul(x, self.lora_A.t())
            lora_out = torch.matmul(lora_out, self.lora_B.t())
            return x + self.lora_scaling * lora_out
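
# A quick sanity check (illustrative, not part of the original module): because `weight`
# and `lora_B` are zero-initialized, a freshly constructed EmbeddingCalibrator acts as the
# identity, so attaching it to a model changes nothing until it is trained.
#
#     calibrator = EmbeddingCalibrator(hidden_size=64, lora_r=8, lora_alpha=16, dtype=torch.float32)
#     x = torch.randn(10, 64)
#     assert torch.allclose(calibrator(x), x)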
class CalibrationModel(nn.Module):
    def __init__(
        self,
        base_model, lm_head, original_vocab_size, num_new_tokens,
        calibrate_embedding=True, calibrate_lm_head=True, empty_init=False,
        lora_alpha=None, lora_r=None,
        target_loss_weight=0.15, subsequent_loss_weight=0.15,
    ):
        super().__init__()
        self.base_model = base_model
        self.lm_head = lm_head
        self.new_tokens_start = original_vocab_size
        self.new_tokens_end = original_vocab_size + num_new_tokens
        self.calibrate_lm_head = calibrate_lm_head
        self.calibrate_embedding = calibrate_embedding
        if not empty_init:
            self.lm_head_calibrator = EmbeddingCalibrator(base_model.config.hidden_size, lora_r, lora_alpha)
            self.embedding_calibrator = EmbeddingCalibrator(base_model.config.hidden_size, lora_r, lora_alpha)
        self.loss_fct = nn.CrossEntropyLoss(reduction="none")
        # Loss weights for: new tokens as targets, tokens that immediately follow a new token,
        # and the remaining original-vocabulary tokens (the three weights sum to 1).
        self.subsequent_tokens_loss_alpha = subsequent_loss_weight
        self.new_tokens_loss_alpha = target_loss_weight
        self.original_tokens_loss_alpha = 1 - self.new_tokens_loss_alpha - self.subsequent_tokens_loss_alpha
    def forward(self, input_ids, labels, attention_mask=None):
        # Shift labels by one position for causal LM next-token prediction
        labels = labels[:, 1:].contiguous()
        input_ids = input_ids[:, :-1].contiguous()

        if self.calibrate_embedding:
            E_weights = self.base_model.get_input_embeddings().weight.data
            # Keep original rows as-is; pass only the new-token rows through the calibrator
            E_weights = torch.cat((E_weights[:self.new_tokens_start], self.embedding_calibrator(E_weights[self.new_tokens_start:])))
            input_embeddings = E_weights[input_ids]
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids, dtype=torch.long)
            outputs = self.base_model(inputs_embeds=input_embeddings, attention_mask=attention_mask)
        else:
            with torch.no_grad():
                # Forward pass through the frozen base model
                outputs = self.base_model(input_ids, attention_mask=attention_mask)

        if self.calibrate_lm_head:
            with torch.no_grad():
                lm_head_weights = self.lm_head.weight
                normed_weights = lm_head_weights.clone()
            # Calibrate only the new-token rows; gradients flow through the calibrator
            normed_weights[self.new_tokens_start:self.new_tokens_end] = self.lm_head_calibrator(lm_head_weights[self.new_tokens_start:self.new_tokens_end])
            logits = torch.matmul(outputs['last_hidden_state'], normed_weights.T)
        else:
            if self.calibrate_embedding:
                logits = self.lm_head(outputs['last_hidden_state'])
            else:
                with torch.no_grad():
                    logits = self.lm_head(outputs['last_hidden_state'])

        per_example_loss = self.loss_fct(logits.transpose(1, 2), labels)
        original_tokens_mask = labels < self.new_tokens_start
        new_tokens_mask = ~original_tokens_mask

        # Weighted sum over original-vocabulary targets, new-token targets,
        # and targets immediately following a new token
        loss = 0.0
        if self.original_tokens_loss_alpha > 0.0:
            loss += self.original_tokens_loss_alpha * per_example_loss[original_tokens_mask].mean()
        if self.new_tokens_loss_alpha > 0.0:
            loss += self.new_tokens_loss_alpha * per_example_loss[new_tokens_mask].mean()
        if self.subsequent_tokens_loss_alpha > 0.0:
            subsequent_tokens_mask = torch.zeros_like(original_tokens_mask, dtype=torch.bool)
            subsequent_tokens_mask[:, 1:][new_tokens_mask[:, :-1]] = True
            loss += self.subsequent_tokens_loss_alpha * per_example_loss[subsequent_tokens_mask].mean()
        return {'loss': loss, 'logits': logits}
    def get_calibrators(self):
        embedding_calibrator = self.embedding_calibrator if self.calibrate_embedding else None
        lm_head_calibrator = self.lm_head_calibrator if self.calibrate_lm_head else None
        return {
            "embedding_calibrator": embedding_calibrator,
            "lm_head_calibrator": lm_head_calibrator,
            "new_tokens_start": self.new_tokens_start,
            "new_tokens_end": self.new_tokens_end,
        }

    def set_calibrators(self, embedding_calibrator=None, lm_head_calibrator=None):
        self.embedding_calibrator = embedding_calibrator
        self.lm_head_calibrator = lm_head_calibrator

    def save_calibrators(self, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        if self.calibrate_embedding:
            torch.save(self.embedding_calibrator, os.path.join(save_dir, "embedding_calibrator.pt"))
        if self.calibrate_lm_head:
            torch.save(self.lm_head_calibrator, os.path.join(save_dir, "lm_head_calibrator.pt"))

    def load_calibrators(self, load_dir, fail_ok=False):
        """Loads the calibrator modules saved by `save_calibrators` from `load_dir`."""
        try:
            if self.calibrate_embedding:
                self.embedding_calibrator = torch.load(os.path.join(load_dir, "embedding_calibrator.pt"))
            if self.calibrate_lm_head:
                self.lm_head_calibrator = torch.load(os.path.join(load_dir, "lm_head_calibrator.pt"))
            return True
        except Exception:
            if fail_ok:
                return False
            raise FileNotFoundError(f"Loading calibrators from '{load_dir}' failed")
def get_calibration_model(model, original_vocab_size, num_new_tokens, target_loss_weight=0.15, subsequent_loss_weight=0.15):
    calibrated_model = CalibrationModel(
        model.model, model.lm_head, original_vocab_size, num_new_tokens,
        target_loss_weight=target_loss_weight, subsequent_loss_weight=subsequent_loss_weight,
    )
    calibrated_model.base_model.eval()
    calibrated_model.lm_head.eval()
    for param in calibrated_model.base_model.parameters():
        param.requires_grad = False
    for param in calibrated_model.lm_head.parameters():
        param.requires_grad = False
    for param in calibrated_model.lm_head_calibrator.parameters():
        param.requires_grad = True
    for param in calibrated_model.embedding_calibrator.parameters():
        param.requires_grad = True
    return calibrated_model
def train_calibration_model(
    calibrated_model: CalibrationModel, tokenizer, dataset, save_dir=None, max_samples=None,
    filter_examples_without_new_tokens=True, lr=1e-4, lr_schedule="linear", num_epochs=1, batch_size=8,
    max_length=256, n_warmup_steps=0, text_col_name="text", clip_grad_norm=1.0, mixed_precision=None,
):
    accelerator = Accelerator(mixed_precision=mixed_precision)

    # Optimizer
    optimizer = optim.AdamW(calibrated_model.parameters(), lr=lr)

    # Tokenize data
    if tokenizer.bos_token is not None and max_length:
        add_start_token = True
        # leave room for <BOS> token to be added:
        max_tokenized_len = max_length - 1
    else:
        add_start_token = False
        max_tokenized_len = max_length

    def _add_start_token(batch):
        bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * batch["input_ids"].size(dim=0)).to(batch["input_ids"].device)
        batch["input_ids"] = torch.cat([bos_tokens_tensor, batch["input_ids"]], dim=1)
        batch["attention_mask"] = torch.cat(
            [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(batch["attention_mask"].device), batch["attention_mask"]], dim=1)
        return batch

    tokenize_function = get_tokenize_func(tokenizer, text_col_name)

    column_names = dataset.column_names
    with accelerator.main_process_first():
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=column_names,
            load_from_cache_file=False,
            desc="Running tokenizer on dataset",
        )
        group_texts = get_group_texts_func(block_size=max_tokenized_len)
        lm_dataset = tokenized_dataset.map(
            group_texts,
            batched=True,
        )

    if filter_examples_without_new_tokens:
        # Keep only examples that contain at least one new token
        examples_w_new_token = np.arange(len(lm_dataset))[np.any(np.array(lm_dataset['input_ids']) >= calibrated_model.new_tokens_start, axis=1)]
        lm_dataset = lm_dataset.select(examples_w_new_token)
    if max_samples is not None:
        lm_dataset = lm_dataset.select(np.arange(max_samples))

    data_collator = default_data_collator

    # Create data loader
    dataloader = DataLoader(
        lm_dataset, collate_fn=data_collator, batch_size=batch_size, drop_last=True, shuffle=True,
    )

    # Learning rate scheduler; a float n_warmup_steps is treated as a fraction of one epoch
    if isinstance(n_warmup_steps, float):
        n_warmup_steps = int(n_warmup_steps * len(dataloader))
    scheduler = get_scheduler(lr_schedule, optimizer=optimizer, num_warmup_steps=n_warmup_steps, num_training_steps=len(dataloader) * num_epochs)

    calibrated_model, dataloader = accelerator.prepare(calibrated_model, dataloader)

    # Freeze the original lm_head weights
    for param in calibrated_model.lm_head.parameters():
        param.requires_grad = False

    calibrated_model.train()
    for epoch in tqdm(range(num_epochs), unit="epochs", desc="Fitting calibration"):
        total_loss = 0.0
        for step, batch in tqdm(enumerate(dataloader), total=len(dataloader), miniters=10, unit="batches"):
            if add_start_token:
                batch = _add_start_token(batch)
            batch["labels"] = batch["input_ids"]

            optimizer.zero_grad()
            outputs = calibrated_model(**batch)
            loss = outputs['loss']
            loss.backward()
            torch.nn.utils.clip_grad_norm_(calibrated_model.parameters(), max_norm=clip_grad_norm)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            # # Log loss
            # if step % 10 == 0:
            #     print(f"Epoch {epoch + 1}, Step {step}, Loss: {loss.item()}")

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch + 1} completed. Average Loss: {avg_loss}")

    if save_dir is not None:
        calibrated_model.save_calibrators(save_dir)
    return calibrated_model
def merge_calibrators_to_hf_model(hf_model, new_tokens_start, new_tokens_end=None, embedding_calibrator=None, lm_head_calibrator=None):
    if embedding_calibrator is not None:
        embedding_calibrator.to(hf_model.device)
        embedding_weights = hf_model.get_input_embeddings().weight
        with torch.no_grad():
            calibrated_weights = embedding_calibrator(embedding_weights[new_tokens_start:new_tokens_end])
            hf_model.model.embed_tokens.weight.data[new_tokens_start:new_tokens_end] = calibrated_weights
    if lm_head_calibrator is not None:
        lm_head_calibrator.to(hf_model.device)
        lm_head_weights = hf_model.get_output_embeddings().weight
        with torch.no_grad():
            calibrated_weights = lm_head_calibrator(lm_head_weights[new_tokens_start:new_tokens_end])
            hf_model.lm_head.weight.data[new_tokens_start:new_tokens_end] = calibrated_weights
    return hf_model
def merge_calibration_model_to_hf_model(hf_model, calibrated_model):
    calibrated_model.to(hf_model.device)
    if calibrated_model.calibrate_lm_head:
        lm_head_weights = calibrated_model.lm_head.weight
        normed_weights = calibrated_model.lm_head_calibrator(lm_head_weights[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end])
        with torch.no_grad():
            hf_model.lm_head.weight.data[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end] = normed_weights
    if calibrated_model.calibrate_embedding:
        embedding_weights = calibrated_model.base_model.get_input_embeddings().weight
        normed_weights = calibrated_model.embedding_calibrator(embedding_weights[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end])
        with torch.no_grad():
            hf_model.model.embed_tokens.weight.data[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end] = normed_weights
    return hf_model
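
# Example end-to-end usage (a minimal sketch, not part of the original API). The checkpoint
# name, dataset, and save paths below are placeholders, and the exact signatures of the
# project helpers `load_lm_dataset` / `extract_new_words_from_dataset` are assumed here:
# a datasets.Dataset with a "text" column and a list of new word strings, respectively.
# The model is assumed to be a standard Hugging Face causal LM exposing `.model` and `.lm_head`.
if __name__ == "__main__":
    set_seed(42)
    model_name = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="cuda")

    dataset = load_lm_dataset("wikitext")  # assumed helper signature
    new_words = extract_new_words_from_dataset(dataset, tokenizer)  # assumed helper signature

    # Add the new words to the vocabulary and resize the embedding / lm_head matrices.
    original_vocab_size = len(tokenizer)
    num_added = tokenizer.add_tokens(new_words)
    model.resize_token_embeddings(len(tokenizer))

    # Fit the calibrators on top of the frozen model, then fold them back into the weights.
    calibrated_model = get_calibration_model(model, original_vocab_size, num_added)
    calibrated_model = train_calibration_model(
        calibrated_model, tokenizer, dataset, save_dir="calibrators", num_epochs=1, batch_size=8,
    )
    model = merge_calibration_model_to_hf_model(model, calibrated_model)
    model.save_pretrained("calibrated-model")
    tokenizer.save_pretrained("calibrated-model")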