import torch
import torch.nn as nn
import numpy as np
import random
import os
import logging
import json
from typing import Optional, Tuple, Union

from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
from transformers.modeling_outputs import TokenClassifierOutput

from .configuration_stacked import ImpressoConfig

# 1. Set random seeds for reproducibility
seed = 2025
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# 2. Force deterministic kernels and disable cuDNN autotuning
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

logger = logging.getLogger(__name__)


def get_info(label_map):
    """Map each task name to its number of token-level labels."""
    num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
    return num_token_labels_dict


class ExtendedMultitaskTimeModelForTokenClassification(PreTrainedModel):
    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, temporal_fusion_strategy="baseline", num_years=327):
        super().__init__(config)
        self.num_token_labels_dict = get_info(config.label_map)
        self.config = config
        self.temporal_fusion_strategy = temporal_fusion_strategy

        self.model = AutoModel.from_pretrained(
            config.pretrained_config["_name_or_path"], config=config.pretrained_config
        )
        self.model.config.use_cache = False
        self.model.config.pretraining_tp = 1
        self.num_years = num_years

        classifier_dropout = (
            getattr(config, "classifier_dropout", 0.1) or config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)

        self.temporal_fusion = TemporalFusion(
            config.hidden_size,
            strategy=self.temporal_fusion_strategy,
            num_years=num_years,
        )

        # Additional transformer layers on top of the backbone
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=config.num_attention_heads
            ),
            num_layers=2,
        )

        # One linear classification head per task
        self.token_classifiers = nn.ModuleDict(
            {
                task: nn.Linear(config.hidden_size, num_labels)
                for task, num_labels in self.num_token_labels_dict.items()
            }
        )

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_labels: Optional[dict] = None,
        date_indices: Optional[torch.Tensor] = None,
        year_index: Optional[torch.Tensor] = None,
        decade_index: Optional[torch.Tensor] = None,
        century_index: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if inputs_embeds is None:
            inputs_embeds = self.model.embeddings(input_ids)

        # Early cross-attention fusion: inject the year embedding before the backbone
        if self.temporal_fusion_strategy == "early-cross-attention":
            year_emb = self.temporal_fusion.compute_time_embedding(year_index)  # (B, H)
            inputs_embeds = self.temporal_fusion.cross_attn(inputs_embeds, year_emb)

        bert_kwargs = {
            "inputs_embeds": inputs_embeds
            if self.temporal_fusion_strategy == "early-cross-attention"
            else None,
            "input_ids": input_ids
            if self.temporal_fusion_strategy != "early-cross-attention"
            else None,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
"position_ids": position_ids, "head_mask": head_mask, "output_attentions": output_attentions, "output_hidden_states": output_hidden_states, "return_dict": return_dict, } if any(keyword in self.config.name_or_path.lower() for keyword in ["llama", "deberta"]): bert_kwargs.pop("token_type_ids", None) bert_kwargs.pop("head_mask", None) outputs = self.model(**bert_kwargs) token_output = self.dropout(outputs[0]) # (B, T, H) hidden_states = list(outputs.hidden_states) if output_hidden_states else None # Pass through additional transformer layers token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose( 0, 1 ) # Apply fusion after transformer if needed if self.temporal_fusion_strategy not in ["baseline", "early-cross-attention"]: token_output = self.temporal_fusion(token_output, year_index) if output_hidden_states: hidden_states.append(token_output) # add the final fused state task_logits = {} total_loss = 0 for task, classifier in self.token_classifiers.items(): logits = classifier(token_output) task_logits[task] = logits if token_labels and task in token_labels: loss_fct = CrossEntropyLoss() loss = loss_fct( logits.view(-1, self.num_token_labels_dict[task]), token_labels[task].view(-1), ) total_loss += loss if not return_dict: output = (task_logits,) + outputs[2:] return ((total_loss,) + output) if total_loss != 0 else output return TokenClassifierOutput( loss=total_loss, logits=task_logits, hidden_states=tuple(hidden_states) if hidden_states is not None else None, attentions=outputs.attentions if output_attentions else None, ) class TemporalFusion(nn.Module): def __init__(self, hidden_size, strategy="add", num_years=327, min_year=1700): super().__init__() self.strategy = strategy self.hidden_size = hidden_size self.min_year = min_year self.max_year = min_year + num_years - 1 self.year_emb = nn.Embedding(num_years, hidden_size) if strategy == "concat": self.concat_proj = nn.Linear(hidden_size * 2, hidden_size) elif strategy == "film": self.film_gamma = nn.Linear(hidden_size, hidden_size) self.film_beta = nn.Linear(hidden_size, hidden_size) elif strategy == "adapter": self.adapter = nn.Sequential( nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, hidden_size), ) elif strategy == "relative": self.relative_encoder = nn.Sequential( nn.Linear(hidden_size, hidden_size), nn.SiLU(), nn.LayerNorm(hidden_size), ) self.film_gamma = nn.Linear(hidden_size, hidden_size) self.film_beta = nn.Linear(hidden_size, hidden_size) elif strategy == "multiscale": self.decade_emb = nn.Embedding(1000, hidden_size) self.century_emb = nn.Embedding(100, hidden_size) elif strategy in ["early-cross-attention", "late-cross-attention"]: self.year_encoder = nn.Sequential( nn.Linear(hidden_size, hidden_size), nn.SiLU() ) self.cross_attn = TemporalCrossAttention(hidden_size) def compute_time_embedding(self, year_index): if self.strategy in ["early-cross-attention", "late-cross-attention"]: return self.year_encoder(self.year_emb(year_index)) elif self.strategy == "multiscale": year_index = year_index.long() year = year_index + self.min_year decade = (year // 10).long() century = (year // 100).long() return ( self.year_emb(year_index) + self.decade_emb(decade) + self.century_emb(century) ) else: return self.year_emb(year_index) def forward(self, token_output, year_index): B, T, H = token_output.size() if self.strategy == "baseline": return token_output year_emb = self.compute_time_embedding(year_index) if self.strategy == "concat": expanded_year = year_emb.unsqueeze(1).repeat(1, T, 1) fused 
            fused = torch.cat([token_output, expanded_year], dim=-1)
            return self.concat_proj(fused)
        elif self.strategy == "film":
            gamma = self.film_gamma(year_emb).unsqueeze(1)
            beta = self.film_beta(year_emb).unsqueeze(1)
            return gamma * token_output + beta
        elif self.strategy == "adapter":
            return token_output + self.adapter(year_emb).unsqueeze(1)
        elif self.strategy == "add":
            expanded_year = year_emb.unsqueeze(1).repeat(1, T, 1)
            return token_output + expanded_year
        elif self.strategy == "relative":
            encoded = self.relative_encoder(year_emb)
            gamma = self.film_gamma(encoded).unsqueeze(1)
            beta = self.film_beta(encoded).unsqueeze(1)
            return gamma * token_output + beta
        elif self.strategy == "multiscale":
            expanded_year = year_emb.unsqueeze(1).expand(-1, T, -1)
            return token_output + expanded_year
        elif self.strategy == "late-cross-attention":
            return self.cross_attn(token_output, year_emb)
        else:
            raise ValueError(f"Unknown fusion strategy: {self.strategy}")


class TemporalCrossAttention(nn.Module):
    def __init__(self, hidden_size, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True
        )

    def forward(self, token_output, time_embedding):
        # token_output: (B, T, H), time_embedding: (B, H)
        time_as_seq = time_embedding.unsqueeze(1)  # (B, 1, H)
        attn_output, _ = self.attn(token_output, time_as_seq, time_as_seq)
        return token_output + attn_output
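

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline):
# exercises TemporalFusion / TemporalCrossAttention on dummy tensors to show
# that every strategy preserves the (B, T, H) shape of the token outputs.
# The hidden size, batch size, sequence length, and year offsets below are
# arbitrary assumptions; the full ExtendedMultitaskTimeModelForTokenClassification
# additionally requires an ImpressoConfig with `label_map` and `pretrained_config`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    hidden_size, batch_size, seq_len = 64, 2, 8  # assumed toy dimensions

    token_output = torch.randn(batch_size, seq_len, hidden_size)  # (B, T, H)
    year_index = torch.tensor([0, 150])  # offsets from min_year (default 1700)

    for strategy in ["add", "concat", "film", "multiscale", "late-cross-attention"]:
        fusion = TemporalFusion(hidden_size, strategy=strategy)
        fused = fusion(token_output, year_index)
        print(strategy, tuple(fused.shape))  # stays (B, T, H) for every strategy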