import torch
import numpy as np
import random
import os

# Seed every RNG in play (Python, NumPy, PyTorch CPU and all CUDA devices)
# so that runs are reproducible.
seed = 2025
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# Prefer deterministic kernels; warn_only=True only emits a warning (instead
# of raising) when an op has no deterministic implementation.
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
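# Note: full CUDA determinism also requires setting CUBLAS_WORKSPACE_CONFIG
# (e.g. ":4096:8") in the environment before CUDA kernels run; without it,
# affected cuBLAS ops error out, or only warn here because of warn_only=True.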

from transformers.modeling_outputs import TokenClassifierOutput
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
import logging

from .configuration_stacked import ImpressoConfig

logger = logging.getLogger(__name__)


def get_info(label_map):
    num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
    return num_token_labels_dict


class ExtendedMultitaskTimeModelForTokenClassification(PreTrainedModel):
    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, temporal_fusion_strategy="baseline", num_years=327):
        super().__init__(config)
        self.num_token_labels_dict = get_info(config.label_map)
        self.config = config
        self.temporal_fusion_strategy = temporal_fusion_strategy

        # Pretrained backbone encoder.
        self.model = AutoModel.from_pretrained(
            config.pretrained_config["_name_or_path"], config=config.pretrained_config
        )
        self.model.config.use_cache = False
        self.model.config.pretraining_tp = 1
        self.num_years = num_years

        classifier_dropout = (
            getattr(config, "classifier_dropout", 0.1) or config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)

        # Injects the year signal into the token representations.
        self.temporal_fusion = TemporalFusion(
            config.hidden_size,
            strategy=self.temporal_fusion_strategy,
            num_years=num_years,
        )

        # Extra transformer layers on top of the backbone output.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=config.num_attention_heads
            ),
            num_layers=2,
        )

        # One linear classification head per task.
        self.token_classifiers = nn.ModuleDict(
            {
                task: nn.Linear(config.hidden_size, num_labels)
                for task, num_labels in self.num_token_labels_dict.items()
            }
        )

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_labels: Optional[dict] = None,
        date_indices: Optional[torch.Tensor] = None,
        year_index: Optional[torch.Tensor] = None,
        decade_index: Optional[torch.Tensor] = None,
        century_index: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # For early cross-attention, fuse the year embedding into the input
        # embeddings before they enter the backbone; the other (non-baseline)
        # strategies fuse after the encoder instead.
        if self.temporal_fusion_strategy == "early-cross-attention":
            if inputs_embeds is None:
                inputs_embeds = self.model.embeddings(input_ids)
            year_emb = self.temporal_fusion.compute_time_embedding(year_index)
            inputs_embeds = self.temporal_fusion.cross_attn(inputs_embeds, year_emb)

        bert_kwargs = {
            "inputs_embeds": inputs_embeds if self.temporal_fusion_strategy == "early-cross-attention" else None,
            "input_ids": input_ids if self.temporal_fusion_strategy != "early-cross-attention" else None,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }

        # Backbones matched here are called without token_type_ids and head_mask.
        if any(keyword in self.config.name_or_path.lower() for keyword in ["llama", "deberta"]):
            bert_kwargs.pop("token_type_ids", None)
            bert_kwargs.pop("head_mask", None)

        outputs = self.model(**bert_kwargs)
        token_output = self.dropout(outputs[0])
        hidden_states = list(outputs.hidden_states) if output_hidden_states else None

        # nn.TransformerEncoderLayer defaults to sequence-first inputs
        # (seq_len, batch, hidden), hence the transposes around the encoder.
        token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose(0, 1)

        # Every strategy except the baseline and early cross-attention fuses
        # the year information here, after the extra encoder layers.
        if self.temporal_fusion_strategy not in ["baseline", "early-cross-attention"]:
            token_output = self.temporal_fusion(token_output, year_index)
            if output_hidden_states:
                hidden_states.append(token_output)

        # One classification head per task; per-task cross-entropy losses are summed.
        task_logits = {}
        total_loss = 0
        for task, classifier in self.token_classifiers.items():
            logits = classifier(token_output)
            task_logits[task] = logits
            if token_labels and task in token_labels:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    logits.view(-1, self.num_token_labels_dict[task]),
                    token_labels[task].view(-1),
                )
                total_loss += loss

        if not return_dict:
            output = (task_logits,) + outputs[2:]
            return ((total_loss,) + output) if token_labels else output

        return TokenClassifierOutput(
            loss=total_loss if token_labels else None,
            logits=task_logits,
            hidden_states=tuple(hidden_states) if hidden_states is not None else None,
            attentions=outputs.attentions if output_attentions else None,
        )
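
# Note on consuming the outputs above: with return_dict=True, forward() returns a
# TokenClassifierOutput whose `logits` field is a dict keyed by task name (the keys
# of config.label_map), so per-token predictions for a given task can be obtained
# with, e.g., outputs.logits[task].argmax(dim=-1).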


class TemporalFusion(nn.Module):
    """Fuses a per-example year index into token representations.

    Supported strategies: "baseline" (no fusion), "add", "concat", "film",
    "adapter", "relative", "multiscale", "early-cross-attention", and
    "late-cross-attention".
    """

    def __init__(self, hidden_size, strategy="add", num_years=327, min_year=1700):
        super().__init__()
        self.strategy = strategy
        self.hidden_size = hidden_size
        self.min_year = min_year
        self.max_year = min_year + num_years - 1

        # One learned embedding per year index in [0, num_years).
        self.year_emb = nn.Embedding(num_years, hidden_size)

        if strategy == "concat":
            self.concat_proj = nn.Linear(hidden_size * 2, hidden_size)
        elif strategy == "film":
            # FiLM-style conditioning: per-feature scale (gamma) and shift (beta).
            self.film_gamma = nn.Linear(hidden_size, hidden_size)
            self.film_beta = nn.Linear(hidden_size, hidden_size)
        elif strategy == "adapter":
            self.adapter = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
            )
        elif strategy == "relative":
            self.relative_encoder = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.SiLU(),
                nn.LayerNorm(hidden_size),
            )
            self.film_gamma = nn.Linear(hidden_size, hidden_size)
            self.film_beta = nn.Linear(hidden_size, hidden_size)
        elif strategy == "multiscale":
            # Additional embeddings at decade and century granularity.
            self.decade_emb = nn.Embedding(1000, hidden_size)
            self.century_emb = nn.Embedding(100, hidden_size)
        elif strategy in ["early-cross-attention", "late-cross-attention"]:
            self.year_encoder = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.SiLU(),
            )
            self.cross_attn = TemporalCrossAttention(hidden_size)

    def compute_time_embedding(self, year_index):
        if self.strategy in ["early-cross-attention", "late-cross-attention"]:
            return self.year_encoder(self.year_emb(year_index))
        elif self.strategy == "multiscale":
            # Sum year-, decade-, and century-level embeddings; decade and
            # century indices are derived from the absolute year.
            year_index = year_index.long()
            year = year_index + self.min_year
            decade = (year // 10).long()
            century = (year // 100).long()
            return (
                self.year_emb(year_index)
                + self.decade_emb(decade)
                + self.century_emb(century)
            )
        else:
            return self.year_emb(year_index)

    def forward(self, token_output, year_index):
        # token_output: (batch, seq_len, hidden); year_index: (batch,)
        B, T, H = token_output.size()

        if self.strategy == "baseline":
            return token_output

        year_emb = self.compute_time_embedding(year_index)

        if self.strategy == "concat":
            # Concatenate the year embedding to every token, then project back.
            expanded_year = year_emb.unsqueeze(1).repeat(1, T, 1)
            fused = torch.cat([token_output, expanded_year], dim=-1)
            return self.concat_proj(fused)

        elif self.strategy == "film":
            # gamma * h + beta, with gamma/beta shared across positions.
            gamma = self.film_gamma(year_emb).unsqueeze(1)
            beta = self.film_beta(year_emb).unsqueeze(1)
            return gamma * token_output + beta

        elif self.strategy == "adapter":
            return token_output + self.adapter(year_emb).unsqueeze(1)

        elif self.strategy == "add":
            expanded_year = year_emb.unsqueeze(1).repeat(1, T, 1)
            return token_output + expanded_year

        elif self.strategy == "relative":
            encoded = self.relative_encoder(year_emb)
            gamma = self.film_gamma(encoded).unsqueeze(1)
            beta = self.film_beta(encoded).unsqueeze(1)
            return gamma * token_output + beta

        elif self.strategy == "multiscale":
            expanded_year = year_emb.unsqueeze(1).expand(-1, T, -1)
            return token_output + expanded_year

        elif self.strategy == "late-cross-attention":
            return self.cross_attn(token_output, year_emb)

        else:
            raise ValueError(f"Unknown fusion strategy: {self.strategy}")


class TemporalCrossAttention(nn.Module):
    def __init__(self, hidden_size, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True
        )

    def forward(self, token_output, time_embedding):
        # Tokens attend to the year embedding, treated as a length-1 sequence,
        # and the attended output is added back residually.
        time_as_seq = time_embedding.unsqueeze(1)  # (batch, 1, hidden)
        attn_output, _ = self.attn(token_output, time_as_seq, time_as_seq)
        return token_output + attn_output
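

# Minimal self-contained sketch (not part of the training pipeline): exercising
# TemporalFusion on random tensors. The hidden size, sequence length, and year
# indices below are arbitrary placeholders. Because of the relative import of
# ImpressoConfig at the top, run this module with `python -m` from the package
# root rather than as a standalone script.
if __name__ == "__main__":
    fusion = TemporalFusion(hidden_size=768, strategy="film", num_years=327)
    tokens = torch.randn(2, 16, 768)    # (batch, seq_len, hidden)
    years = torch.tensor([10, 250])     # year indices in [0, num_years)
    print(fusion(tokens, years).shape)  # torch.Size([2, 16, 768])

    late = TemporalFusion(768, strategy="late-cross-attention", num_years=327)
    print(late(tokens, years).shape)    # torch.Size([2, 16, 768])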