File size: 4,598 Bytes
f2d0a8e 6a27d7e f2d0a8e 6a27d7e f2d0a8e 6a27d7e f2d0a8e 6a27d7e f2d0a8e 6a27d7e f2d0a8e 6a27d7e f2d0a8e 6a27d7e f2d0a8e 6a27d7e f2d0a8e 6a27d7e f2d0a8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
from transformers import PretrainedConfig, PreTrainedModel, AutoProcessor, SiglipModel
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
class ExplainerConfig(PretrainedConfig):
    """Configuration for the ``Explainer`` model.

    Attributes set here:
        base_model_name: identifier of the SigLIP checkpoint that
            ``Explainer`` loads as its backbone and processor.
        hidden_dim: stored regressor width setting.
        giant: stored flag for the regressor's ``giant`` option.

    Any remaining keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "explainer"

    def __init__(
        self,
        base_model_name="google/siglip2-giant-opt-patch16-384",
        hidden_dim=768,
        giant=True,
        **kwargs,
    ):
        # Own fields are assigned first; the base class then consumes the
        # remaining keyword arguments.
        self.base_model_name, self.hidden_dim, self.giant = (
            base_model_name,
            hidden_dim,
            giant,
        )
        super().__init__(**kwargs)
class SigLIPBBoxRegressor(nn.Module):
    """Regress four box values from SigLIP image and text embeddings.

    The wrapped SigLIP model is called without gradient tracking; its
    ``image_embeds`` and ``text_embeds`` are projected to ``hidden_dim``,
    concatenated, fused by an MLP and passed to two independent heads that
    each emit two values (commented in the layers as ``(x1, y1)`` and
    ``(x2, y2)``).

    Args:
        siglip_model: module exposing ``vision_model.config.hidden_size`` and
            ``text_model.config.hidden_size`` and returning an object with
            ``image_embeds`` / ``text_embeds`` when called.
        hidden_dim: width of both projections and of the first fusion layer.
        giant: when true, the text projector input width is fixed to 1536
            instead of the text model's configured hidden size.
    """

    def __init__(self, siglip_model, hidden_dim=768, giant=True):
        super().__init__()
        self.siglip = siglip_model

        image_width = self.siglip.vision_model.config.hidden_size
        configured_text_width = self.siglip.text_model.config.hidden_size
        # The giant flag overrides whatever the text config reports.
        text_width = 1536 if giant else configured_text_width
        half_width = hidden_dim // 2

        # Per-modality projections into the shared hidden width.
        self.vision_projector = self._projector(image_width, hidden_dim)
        self.text_projector = self._projector(text_width, hidden_dim)

        # Cross-modal fusion MLP over the concatenated projections.
        self.fusion_layer = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, half_width),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

        # One head per box corner: (x1, y1) and (x2, y2).
        self.topleft_regressor = self._corner_head(half_width)
        self.bottomright_regressor = self._corner_head(half_width)

    @staticmethod
    def _projector(in_features, out_features):
        """Build a Linear -> ReLU -> Dropout(0.1) projection stack."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

    @staticmethod
    def _corner_head(in_features):
        """Build the MLP head that outputs the two values of one corner."""
        return nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
        )

    def forward(self, pixel_values, input_ids):
        """Return the two heads' outputs concatenated along dim 1.

        Args:
            pixel_values: image input forwarded to the SigLIP model.
            input_ids: token ids forwarded to the SigLIP model.

        Returns:
            Tensor made of the top-left head output followed by the
            bottom-right head output (four values per row for 2-D features).
        """
        # The backbone call is excluded from autograd.
        with torch.no_grad():
            backbone_out = self.siglip(
                pixel_values=pixel_values,
                input_ids=input_ids,
                return_dict=True,
            )

        # Embeddings are cast to float32 before entering the trainable layers.
        image_vec = self.vision_projector(backbone_out.image_embeds.float())
        text_vec = self.text_projector(backbone_out.text_embeds.float())

        joint = self.fusion_layer(torch.cat((image_vec, text_vec), dim=1))

        corners = [
            self.topleft_regressor(joint),
            self.bottomright_regressor(joint),
        ]
        return torch.cat(corners, dim=1)
class Explainer(PreTrainedModel):
    """SigLIP backbone plus a bounding-box regressor, packaged as a HF model.

    The same ``SiglipModel`` instance is held both as ``self.siglip_model``
    and inside ``self.bbox_regressor`` (as its ``siglip`` submodule), so the
    two attributes share weights.
    """

    config_class = ExplainerConfig

    def __init__(self, config):
        """Build the backbone, regressor and processor described by ``config``.

        Args:
            config: configuration exposing ``base_model_name``; ``hidden_dim``
                and ``giant`` are honoured when present and otherwise fall
                back to the regressor's defaults (768 / True).
        """
        super().__init__(config)
        self.siglip_model = SiglipModel.from_pretrained(config.base_model_name)
        # Pass the configured regressor settings through instead of silently
        # using the regressor defaults; getattr keeps plain PretrainedConfig
        # objects without these fields working.
        self.bbox_regressor = SigLIPBBoxRegressor(
            self.siglip_model,
            hidden_dim=getattr(config, "hidden_dim", 768),
            giant=getattr(config, "giant", True),
        )
        self.processor = AutoProcessor.from_pretrained(
            config.base_model_name, use_fast=True
        )

    def forward(self, pixel_values=None, input_ids=None):
        """Run the regressor and return its concatenated head outputs."""
        return self.bbox_regressor(pixel_values, input_ids)

    def predict(self, image, text, device="cuda"):
        """Predict box values for one image/text pair.

        Moves the model to ``device``, switches to eval mode, preprocesses the
        inputs with the bundled processor (text padded/truncated to 64 tokens)
        and runs a forward pass without gradients.

        Args:
            image: image input accepted by the processor.
            text: text input accepted by the processor.
            device: device the model and inputs are moved to.

        Returns:
            The first row of the model output as a Python list
            (top-left head values followed by bottom-right head values).
            NOTE(review): coordinate units/normalisation are not visible
            here - confirm against the training code.
        """
        self.to(device)
        self.eval()

        inputs = self.processor(
            text=text,
            images=image,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=64,
        )

        # Match the backbone's parameter dtype rather than forcing float16:
        # a hard-coded .half() raises a dtype mismatch whenever the backbone
        # weights are float32 or bfloat16.
        backbone_dtype = next(self.siglip_model.parameters()).dtype
        pixel_values = inputs["pixel_values"].to(device=device, dtype=backbone_dtype)
        input_ids = inputs["input_ids"].to(device)

        with torch.no_grad():
            pred_bbox = self.forward(pixel_values, input_ids)

        return pred_bbox[0].cpu().numpy().tolist()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Instantiate the model and load ``pytorch_model.bin`` weights.

        Args:
            pretrained_model_name_or_path: Hub repo id, or a local directory
                containing ``pytorch_model.bin`` (and the config when no
                ``config`` kwarg is given).
            **kwargs: only ``config`` is consumed; other entries are ignored.

        Returns:
            The model with ``siglip_model`` and ``bbox_regressor`` state
            restored from the checkpoint's ``"siglip_model"`` and
            ``"bbox_regressor"`` entries.
        """
        import os

        config = kwargs.pop("config", None)
        if config is None:
            # Use the model's own config class so its defaults are applied.
            config = cls.config_class.from_pretrained(pretrained_model_name_or_path)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            checkpoint_path = os.path.join(
                pretrained_model_name_or_path, "pytorch_model.bin"
            )
        else:
            checkpoint_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="pytorch_model.bin",
            )

        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # from trusted sources; consider weights_only=True on torch versions
        # that support it.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.siglip_model.load_state_dict(checkpoint["siglip_model"])
        model.bbox_regressor.load_state_dict(checkpoint["bbox_regressor"])
        return model
|