import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLForConditionalGeneration
from gui_actor.constants import IGNORE_INDEX
from typing import List, Tuple, Union, Optional
from gui_actor.trainer import rank0_print


class QwenVLwithVisionHeadOutputWithPast(Qwen2VLCausalLMOutputWithPast):
    """
    Output class for Qwen2VL with pointer head, extending the base output class.

    Args:
        lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
            Language modeling loss.
        pointer_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
            Vision pointer network loss.
        pointer_scores (`List[torch.FloatTensor]`, *optional*):
            Attention scores from the pointer network, one tensor per batch item.
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
            Combined loss (weighted sum of lm_loss and pointer_loss).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores from the language modeling head.
        past_key_values, hidden_states, attentions, rope_deltas:
            Same as parent class.
    """

    def __init__(self, lm_loss=None, pointer_loss=None, pointer_scores=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lm_loss = lm_loss
        self.pointer_loss = pointer_loss
        self.pointer_scores = pointer_scores


class VisionHead_MultiPatch(nn.Module):
    def __init__(self, d_model, projection_dim, num_attention_heads=8, dropout_rate=0.1):
        super().__init__()
        self.d_model = d_model
        # Note: We omit additional normalization here because Qwen2VL
        # already normalizes hidden states using RMSNorm.
        self.projection_enc = nn.Sequential(
            nn.Linear(d_model, projection_dim),
            nn.GELU(),
            nn.Linear(projection_dim, d_model)
        )
        self.projection_dec = nn.Sequential(
            nn.Linear(d_model, projection_dim),
            nn.GELU(),
            nn.Linear(projection_dim, d_model)
        )
        # Add self-attention layer for visual features
        self.self_attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_attention_heads,
            dropout=dropout_rate,
            batch_first=True
        )
        # Layer normalization and residual connection
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self,
                hidden_state_enc,  # shape: [n_enc, d_model]; n_enc varies with image size
                hidden_state_dec,  # shape: [n_dec, d_model]; a sample can contain multiple queries
                labels: Optional[torch.Tensor] = None,  # shape: [n_dec, n_enc], binary mask of patches in bbox
                do_single_patch: bool = False,
                ):
        enc_input = hidden_state_enc.unsqueeze(0)

        attn_output, _ = self.self_attention(
            query=enc_input,
            key=enc_input,
            value=enc_input,
            # attn_mask=attention_mask,
            need_weights=False
        )

        # Residual connection and layer normalization
        hidden_state_enc_ctx = self.layer_norm(enc_input + self.dropout(attn_output))
        # Remove the batch dimension
        hidden_state_enc_ctx = hidden_state_enc_ctx.squeeze(0)  # [n_enc, d_model]

        # Apply the projection networks.
        proj_enc = self.projection_enc(hidden_state_enc_ctx)  # [n_enc, d_model]
        proj_dec = self.projection_dec(hidden_state_dec)  # [n_dec, d_model]

        # Compute scaled dot-product attention scores.
        # Scaling by sqrt(d_model) keeps the logit magnitudes stable regardless of the (variable) number of visual tokens n_enc.
        scaling = self.d_model ** 0.5
        patch_logits = torch.matmul(proj_dec, proj_enc.transpose(0, 1)) / scaling  # [n_dec, n_enc]
        # Softmax normalization along the encoder (visual-token) dimension.
        attn_weights = F.softmax(patch_logits, dim=-1)

        loss = None
        if (labels is not None) and (not do_single_patch):
            epsilon = 1e-8
            labels_float = labels.float()
            # Normalize each row to get the target probability distribution
            target_dist = labels_float / (labels_float.sum(dim=-1, keepdim=True) + epsilon)
            # Apply log_softmax to the logits
            pred_log_probs = F.log_softmax(patch_logits, dim=-1)
            # Use KL divergence as the loss
            loss = F.kl_div(pred_log_probs, target_dist, reduction='batchmean')

        if do_single_patch and (labels is not None):
            # Single-patch supervision: labels are class indices over the visual tokens,
            # so apply cross-entropy to the unnormalized logits.
            loss = F.cross_entropy(patch_logits, labels)

        return attn_weights, loss
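

# The helper below is an illustrative sketch, not part of the original GUI-Actor code: it feeds
# random tensors through VisionHead_MultiPatch to show the expected shapes and the KL-divergence
# loss. The function name and all sizes (d_model=64, n_enc=16, n_dec=2) are arbitrary example values.
def _demo_multi_patch_head():
    d_model, n_enc, n_dec = 64, 16, 2
    head = VisionHead_MultiPatch(d_model=d_model, projection_dim=d_model, num_attention_heads=4)
    hidden_state_enc = torch.randn(n_enc, d_model)  # stand-in for visual-token hidden states
    hidden_state_dec = torch.randn(n_dec, d_model)  # stand-in for pointer-query hidden states
    labels = torch.zeros(n_dec, n_enc)
    labels[:, :4] = 1.0  # pretend the first 4 patches fall inside the target bbox
    attn_weights, loss = head(hidden_state_enc, hidden_state_dec, labels=labels)
    # attn_weights: [n_dec, n_enc], each row sums to 1; loss: scalar KL divergence
    return attn_weights, loss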


class Qwen2VLForConditionalGenerationWithPointer(Qwen2VLForConditionalGeneration):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.multi_patch_pointer_head = VisionHead_MultiPatch(self.config.hidden_size, self.config.hidden_size)
        self.pointer_loss_weight = kwargs.get("pointer_loss_weight", 1.0)
        self.lm_loss_weight = kwargs.get("lm_loss_weight", 1.0)
        self.post_init()

    def reset_loss_weights(self, pointer_loss_weight, lm_loss_weight):
        self.pointer_loss_weight = pointer_loss_weight
        self.lm_loss_weight = lm_loss_weight

    def forward(self,
                input_ids: torch.LongTensor = None,  # (batch_size, seq_len)
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                pixel_values: Optional[torch.Tensor] = None,
                pixel_values_videos: Optional[torch.FloatTensor] = None,
                image_grid_thw: Optional[torch.LongTensor] = None,
                video_grid_thw: Optional[torch.LongTensor] = None,
                rope_deltas: Optional[torch.LongTensor] = None,
                cache_position: Optional[torch.LongTensor] = None,
                # Grounding
                visual_token_indices_of_coordinates: Optional[torch.Tensor] = None,  # shape: (batch_size, n_target); each element is the ground-truth index of the visual token that the corresponding target token should attend to
                multi_patch_labels: Optional[torch.Tensor] = None,  # list of (n_target, n_visual) binary masks of the patches inside the bbox
                if_multi_patch: bool = True,
                coordinates: Optional[List[Tuple[float, float]]] = None,
                verbose: bool = False) -> Union[Tuple, QwenVLwithVisionHeadOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if verbose:
            rank0_print(f"input_ids: {input_ids.shape}, {input_ids[0][:5]}...")
            rank0_print(f"labels: {labels.shape}, {labels[0][:5]}...")
            rank0_print(f"pixel_values: {pixel_values.shape}")
            rank0_print(f"image_grid_thw: {image_grid_thw.shape}, {image_grid_thw}")
            rank0_print(f"coordinates: {coordinates}")
            rank0_print(f"visual_token_indices_of_coordinates: {visual_token_indices_of_coordinates}")
            rank0_print(f"return_dict: {return_dict}")

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)  # (batch_size, seq_len, d_model)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.dtype)
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                image_mask = (
                    (input_ids == self.config.image_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
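                # At this point the embeddings of the image placeholder tokens have been replaced
                # in place by the visual features produced by the vision tower.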

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )
                video_mask = (
                    (input_ids == self.config.video_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids, image_grid_thw, video_grid_thw, attention_mask
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `delta` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                    delta = delta.to(position_ids.device)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
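
        # Note (added for clarity): Qwen2-VL uses multimodal RoPE, so `position_ids` here has
        # shape (3, batch_size, seq_len), one index per temporal / height / width axis.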
        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]  # (batch_size, seq_len, d_model)
        logits = self.lm_head(hidden_states)

        lm_loss = None
        if labels is not None and self.lm_loss_weight > 0:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            lm_loss = loss_fct(shift_logits, shift_labels)

        # If vision supervision is requested, run the pointer head.
        pointer_loss = None
        pointer_scores = []
        if visual_token_indices_of_coordinates is not None:
            batch_size = input_ids.shape[0]
            pointer_losses = []
            # Process each sample individually because the number of visual and target tokens may vary.
            for i in range(batch_size):
                dummy_target = False
                # Get the token ids and corresponding hidden states for sample i.
                token_ids = input_ids[i]  # (seq_len,)
                hs = hidden_states[i]  # (seq_len, d_model)

                # Identify the indices of the visual tokens.
                visual_mask = (token_ids == self.config.image_token_id)
                visual_indices = torch.nonzero(visual_mask, as_tuple=False).squeeze(-1)  # (n_visual,)

                # Identify the target tokens (the ones that should attend to visual features).
                target_mask = (token_ids == self.config.pointer_pad_token_id)
                target_indices = torch.nonzero(target_mask, as_tuple=False).squeeze(-1)

                # Visual tokens must be present; if no target tokens are found, fall back to a
                # dummy target so that this sample still contributes a (zeroed-out) pointer loss.
                if visual_indices.numel() == 0:
                    raise ValueError(f"No visual tokens found for sample {i}.")
                if target_indices.numel() == 0:
                    target_indices = torch.tensor([hs.shape[0] - 1], device=hs.device)  # take the last token as the dummy target token
                    gt = torch.tensor([0]).to(hs.device)  # take the first visual token as the dummy ground truth
                    if if_multi_patch:  # take the first 4 visual tokens as the dummy ground truth
                        sample_labels = torch.zeros_like(visual_indices).unsqueeze(0)
                        sample_labels[0][:4] = 1
                    dummy_target = True
                else:
                    # For supervision, we assume that visual_token_indices_of_coordinates[i] is a tensor of shape (n_target,)
                    # where each element is an integer in the range [0, n_visual-1] indicating the ground-truth visual token.
                    gt = visual_token_indices_of_coordinates[i].to(hs.device)  # (n_target,)
                    if if_multi_patch:
                        sample_labels = multi_patch_labels[i]

                # Gather the corresponding hidden state representations.
                # visual_hidden = hs[visual_indices]  # (n_visual, d_model)
                visual_embeds = inputs_embeds[i][visual_indices]
                target_hidden = hs[target_indices]  # (n_target, d_model)

                # Calculate loss for multi-patch mode
                if if_multi_patch:
                    # Ensure the number of targets matches between sample and labels
                    if sample_labels.shape[0] != target_indices.shape[0]:
                        raise ValueError(
                            f"Sample {i} has mismatched target counts: {sample_labels.shape[0]} labels but found {target_indices.shape[0]} target tokens"
                        )
                    # Process using VisionHead_MultiPatch
                    attn_scores, loss_v = self.multi_patch_pointer_head(
                        visual_embeds,
                        target_hidden,
                        labels=sample_labels
                    )
                else:
                    # Deprecated single-patch branch; note that `pointer_head` is not initialized
                    # in this class, so this path is effectively unused.
                    attn_scores, loss_v = self.pointer_head(visual_embeds, target_hidden, labels=gt)

                pointer_scores.append(attn_scores.detach().cpu())
                pointer_losses.append(loss_v * 0.0 if dummy_target else loss_v)

            pointer_loss = torch.stack(pointer_losses).mean()
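            # Note (added for clarity): samples with only a dummy target contribute 0 to the sum
            # above but still count in the mean's denominator.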

        # Combine the LM loss and vision loss using the provided loss weights.
        if lm_loss is None:
            total_loss = pointer_loss
        elif pointer_loss is None:
            total_loss = lm_loss
        else:
            total_loss = self.lm_loss_weight * lm_loss + self.pointer_loss_weight * pointer_loss

        if return_dict:
            return QwenVLwithVisionHeadOutputWithPast(
                lm_loss=lm_loss,
                pointer_loss=pointer_loss,
                pointer_scores=pointer_scores,
                loss=total_loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                rope_deltas=self.rope_deltas,
            )
        else:
            # When labels are provided, the parent's forward returns a tuple with the loss as the first
            # element, so mirror that layout here while also exposing the individual losses and pointer scores.
            if labels is not None:
                output = (lm_loss, pointer_loss, logits, pointer_scores,) + outputs[1:]
                print("returning: total_loss, lm_loss, pointer_loss, logits, pointer_scores, ...")
                return (total_loss,) + output if total_loss is not None else output
            else:
                return outputs
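

if __name__ == "__main__":
    # Illustrative-only entry point (not in the original module): exercises the standalone
    # pointer-head sketch defined above on random tensors. The full model is not instantiated
    # here because it requires a Qwen2-VL checkpoint and processor.
    demo_weights, demo_loss = _demo_multi_patch_head()
    print(f"demo attn weights: {tuple(demo_weights.shape)}, demo pointer loss: {demo_loss.item():.4f}")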