# Last commit: f8db31b (emanuelaboros) - "tentative to add random seeds"
import torch
import numpy as np
import random
import os
# ---------------------------------------------------------------------------
# Reproducibility. Everything below runs once, as a side effect of importing
# this module.
# ---------------------------------------------------------------------------
seed = 2025

# Seed Python's `random`, NumPy's global RNG and PyTorch (CPU, then every CUDA
# device).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# PYTHONHASHSEED is read when an interpreter starts, so assigning it here does
# not change str hashing in the *current* process; Python subprocesses started
# afterwards inherit it.
os.environ["PYTHONHASHSEED"] = str(seed)

# Request deterministic kernels from PyTorch/cuDNN. This does NOT disable
# dropout (that is governed by train()/eval() mode). With warn_only=True, ops
# that have no deterministic implementation emit a warning instead of raising.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True, warn_only=True)
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
import logging, json, os
from .configuration_stacked import ImpressoConfig
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
def get_info(label_map):
    """Return the number of labels for each task.

    Args:
        label_map: Mapping of task name to the collection of labels for that task.

    Returns:
        dict mapping each task name to ``len`` of its labels, preserving the
        iteration order of ``label_map``.
    """
    counts = {}
    for task_name, task_labels in label_map.items():
        counts[task_name] = len(task_labels)
    return counts
class ExtendedMultitaskTimeModelForTokenClassification(PreTrainedModel):
    """Multitask token classifier with optional temporal (year) conditioning.

    The model stacks the following pieces:

    * A pretrained backbone, loaded with ``AutoModel`` from
      ``config.pretrained_config``.
    * Dropout.
    * A 2-layer ``nn.TransformerEncoder``.
    * One linear head per task in ``config.label_map``.

    A ``TemporalFusion`` module injects a learned year embedding at one of
    three points:

    * ``"early-cross-attention"``: into the input embeddings.
    * Every other non-``"baseline"`` strategy: after the extra encoder layers.
    * ``"baseline"``: not at all.

    Args:
        config: ``ImpressoConfig`` providing ``label_map``, ``pretrained_config``,
            ``hidden_size`` and ``num_attention_heads``. It also provides
            ``hidden_dropout_prob`` when ``classifier_dropout`` is ``None``.
        temporal_fusion_strategy: Strategy name passed to ``TemporalFusion``.
        num_years: Size of the year-embedding table.
    """

    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, temporal_fusion_strategy="baseline", num_years=327):
        super().__init__(config)
        self.num_token_labels_dict = get_info(config.label_map)
        self.config = config
        self.temporal_fusion_strategy = temporal_fusion_strategy
        self.model = AutoModel.from_pretrained(
            config.pretrained_config["_name_or_path"], config=config.pretrained_config
        )
        self.model.config.use_cache = False
        self.model.config.pretraining_tp = 1
        self.num_years = num_years
        # Fall back to hidden_dropout_prob only when classifier_dropout is unset
        # (None). An explicit 0.0 must be honoured, which a plain `or` would not do.
        classifier_dropout = getattr(config, "classifier_dropout", 0.1)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.temporal_fusion = TemporalFusion(
            config.hidden_size,
            strategy=self.temporal_fusion_strategy,
            num_years=num_years,
        )
        # Two additional encoder layers on top of the backbone. The layer uses the
        # default batch_first=False, so forward() transposes to (T, B, H).
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=config.num_attention_heads
            ),
            num_layers=2,
        )
        self.token_classifiers = nn.ModuleDict({
            task: nn.Linear(config.hidden_size, num_labels)
            for task, num_labels in self.num_token_labels_dict.items()
        })
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_labels: Optional[dict] = None,
        date_indices: Optional[torch.Tensor] = None,
        year_index: Optional[torch.Tensor] = None,
        decade_index: Optional[torch.Tensor] = None,
        century_index: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        """Run the backbone, the extra encoder layers and every task head.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds, output_attentions, output_hidden_states: Forwarded to
                the backbone. ``token_type_ids`` and ``head_mask`` are dropped when
                ``config.name_or_path`` mentions llama or deberta. If both
                ``input_ids`` and ``inputs_embeds`` are given, ``input_ids`` takes
                precedence except under ``"early-cross-attention"``.
            token_labels: Optional mapping of task name to label tensor. The
                cross-entropy losses of the tasks present are summed.
            year_index: Index into the year-embedding table. Required by every
                strategy other than ``"baseline"``.
            labels, date_indices, decade_index, century_index: Accepted for
                interface compatibility; not used.
            return_dict: Return a ``TokenClassifierOutput`` (default taken from
                the config) instead of a plain tuple.

        Returns:
            When ``return_dict`` is true, a ``TokenClassifierOutput`` with:

            * ``logits``: a dict mapping task name to that head's output.
            * ``loss``: ``None`` when no task labels were given.

            Otherwise the tuple ``([loss,] task_logits, *backbone_extras)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.temporal_fusion_strategy == "early-cross-attention":
            if inputs_embeds is None:
                # Only computed when it is actually used. Backbones without an
                # `embeddings` submodule fall back to their input-embedding table.
                embed = getattr(self.model, "embeddings", None)
                if embed is None:
                    embed = self.model.get_input_embeddings()
                # NOTE(review): for BERT-style backbones `embeddings(...)` presumably
                # already adds position/token-type embeddings, which the backbone adds
                # again when given `inputs_embeds`. Kept for parity with existing
                # checkpoints -- confirm whether this is intended.
                inputs_embeds = embed(input_ids)
            year_emb = self.temporal_fusion.compute_time_embedding(year_index)  # (B, H)
            inputs_embeds = self.temporal_fusion.cross_attn(inputs_embeds, year_emb)
            input_ids = None
        elif input_ids is not None:
            # Outside early fusion, input_ids wins when both are supplied; a
            # caller-provided inputs_embeds is used only when input_ids is absent.
            inputs_embeds = None

        backbone_kwargs = {
            "input_ids": input_ids,
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            # Always request a ModelOutput so `.hidden_states` / `.attentions` can be
            # read by name. ModelOutput integer/slice indexing mirrors the tuple
            # form, so the non-dict return below is unchanged.
            "return_dict": True,
        }
        if any(keyword in self.config.name_or_path.lower() for keyword in ("llama", "deberta")):
            backbone_kwargs.pop("token_type_ids", None)
            backbone_kwargs.pop("head_mask", None)
        outputs = self.model(**backbone_kwargs)

        token_output = self.dropout(outputs[0])  # (B, T, H)
        hidden_states = list(outputs.hidden_states) if output_hidden_states else None

        token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose(0, 1)

        # Late fusion: every strategy except these two is applied after the encoder.
        if self.temporal_fusion_strategy not in ("baseline", "early-cross-attention"):
            token_output = self.temporal_fusion(token_output, year_index)
            if output_hidden_states:
                hidden_states.append(token_output)  # add the final fused state

        task_logits = {}
        total_loss = None
        loss_fct = CrossEntropyLoss()
        for task, classifier in self.token_classifiers.items():
            logits = classifier(token_output)
            task_logits[task] = logits
            if token_labels and task in token_labels:
                loss = loss_fct(
                    logits.view(-1, self.num_token_labels_dict[task]),
                    token_labels[task].view(-1),
                )
                total_loss = loss if total_loss is None else total_loss + loss

        if not return_dict:
            output = (task_logits,) + outputs[2:]
            # A computed loss is included even when it happens to be exactly zero.
            return output if total_loss is None else (total_loss,) + output
        return TokenClassifierOutput(
            loss=total_loss,
            logits=task_logits,
            hidden_states=tuple(hidden_states) if hidden_states is not None else None,
            attentions=outputs.attentions if output_attentions else None,
        )
class TemporalFusion(nn.Module):
    """Inject a learned year embedding into token representations.

    ``year_index`` indexes a table of ``num_years`` embeddings. The
    ``"multiscale"`` strategy treats index ``i`` as calendar year
    ``min_year + i``.

    Strategies applied by ``forward``:
        baseline: return ``token_output`` unchanged.
        add: add the year embedding to every token.
        concat: concatenate it to every token, then project back to
            ``hidden_size``.
        film: ``gamma(year) * tokens + beta(year)``.
        adapter: add an MLP of the year embedding to every token.
        relative: FiLM conditioned on a Linear-SiLU-LayerNorm encoding of the
            year.
        multiscale: add the year, decade and century embeddings.
        late-cross-attention: residual cross-attention from tokens to the
            encoded year.
        early-cross-attention: intended for input embeddings through
            ``compute_time_embedding`` and ``cross_attn``; ``forward`` rejects it.

    Args:
        hidden_size: Width of the token representations and embeddings.
        strategy: One of the names above.
        num_years: Number of rows in the year-embedding table.
        min_year: Calendar year of index 0.

    Raises:
        ValueError: If ``strategy`` is not one of the names above.
    """

    # Every strategy accepted by __init__; anything else is rejected up front.
    _STRATEGIES = frozenset({
        "baseline",
        "add",
        "concat",
        "film",
        "adapter",
        "relative",
        "multiscale",
        "early-cross-attention",
        "late-cross-attention",
    })

    def __init__(self, hidden_size, strategy="add", num_years=327, min_year=1700):
        super().__init__()
        if strategy not in self._STRATEGIES:
            raise ValueError(f"Unknown fusion strategy: {strategy}")
        self.strategy = strategy
        self.hidden_size = hidden_size
        self.min_year = min_year
        self.max_year = min_year + num_years - 1
        # Created for every strategy, including "baseline" where forward() never reads it.
        self.year_emb = nn.Embedding(num_years, hidden_size)
        if strategy == "concat":
            self.concat_proj = nn.Linear(hidden_size * 2, hidden_size)
        elif strategy == "film":
            self.film_gamma = nn.Linear(hidden_size, hidden_size)
            self.film_beta = nn.Linear(hidden_size, hidden_size)
        elif strategy == "adapter":
            self.adapter = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
            )
        elif strategy == "relative":
            self.relative_encoder = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.SiLU(),
                nn.LayerNorm(hidden_size),
            )
            self.film_gamma = nn.Linear(hidden_size, hidden_size)
            self.film_beta = nn.Linear(hidden_size, hidden_size)
        elif strategy == "multiscale":
            # Indexed by absolute decade (year // 10) and century (year // 100).
            self.decade_emb = nn.Embedding(1000, hidden_size)
            self.century_emb = nn.Embedding(100, hidden_size)
        elif strategy in ("early-cross-attention", "late-cross-attention"):
            self.year_encoder = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.SiLU(),
            )
            self.cross_attn = TemporalCrossAttention(hidden_size)

    def compute_time_embedding(self, year_index):
        """Return the time embedding for ``year_index`` (one row per index).

        * Cross-attention strategies pass the year embedding through
          ``year_encoder``.
        * ``"multiscale"`` sums the year, decade and century embeddings. Decade
          and century are derived from ``min_year + year_index``.
        * Every other strategy returns the raw year embedding.
        """
        if self.strategy in ("early-cross-attention", "late-cross-attention"):
            return self.year_encoder(self.year_emb(year_index))
        if self.strategy == "multiscale":
            year_index = year_index.long()
            year = year_index + self.min_year
            decade = (year // 10).long()
            century = (year // 100).long()
            return (
                self.year_emb(year_index)
                + self.decade_emb(decade)
                + self.century_emb(century)
            )
        return self.year_emb(year_index)

    def forward(self, token_output, year_index):
        """Fuse the year embedding into ``token_output`` (B, T, H) per ``self.strategy``.

        Raises:
            ValueError: If ``token_output`` is not 3-D, or the strategy is
                ``"early-cross-attention"`` (not applied through ``forward``).
        """
        # Unpacking also rejects inputs that are not 3-D (ValueError).
        _, seq_len, _ = token_output.size()
        if self.strategy == "baseline":
            return token_output
        if self.strategy == "early-cross-attention":
            raise ValueError(
                "'early-cross-attention' fuses time into the input embeddings; "
                "use compute_time_embedding() and cross_attn() instead of forward()"
            )

        year_emb = self.compute_time_embedding(year_index)
        # (B, 1, H): broadcasts across the T token positions without copying.
        per_token_year = year_emb.unsqueeze(1)

        if self.strategy == "concat":
            expanded_year = per_token_year.expand(-1, seq_len, -1)
            return self.concat_proj(torch.cat([token_output, expanded_year], dim=-1))
        if self.strategy == "film":
            gamma = self.film_gamma(year_emb).unsqueeze(1)
            beta = self.film_beta(year_emb).unsqueeze(1)
            return gamma * token_output + beta
        if self.strategy == "adapter":
            return token_output + self.adapter(year_emb).unsqueeze(1)
        if self.strategy in ("add", "multiscale"):
            return token_output + per_token_year
        if self.strategy == "relative":
            encoded = self.relative_encoder(year_emb)
            gamma = self.film_gamma(encoded).unsqueeze(1)
            beta = self.film_beta(encoded).unsqueeze(1)
            return gamma * token_output + beta
        if self.strategy == "late-cross-attention":
            return self.cross_attn(token_output, year_emb)
        # Only reachable if self.strategy was reassigned after construction.
        raise ValueError(f"Unknown fusion strategy: {self.strategy}")
class TemporalCrossAttention(nn.Module):
    """Residual multi-head cross-attention from token states to a time embedding.

    The token states are the queries and the time embedding is the single
    key/value. With only one key, every attention weight is 1. Every token
    therefore receives the same additive update: the projected time embedding,
    which does not depend on the token states.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, hidden_size, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            batch_first=True,
        )

    def forward(self, token_output, time_embedding):
        """Return ``token_output`` (B, T, H) plus its attention over ``time_embedding`` (B, H)."""
        memory = time_embedding.unsqueeze(dim=1)  # (B, 1, H): a length-1 key/value sequence
        attended = self.attn(token_output, memory, memory)[0]
        return token_output + attended