#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 Vladislav Lialin and Namrata Shivagunde
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, input_size, hidden, num_heads, causal=False):
        """Multi-head attention module which computes [softmax(x Q_h @ (x K_h)^T / scale) @ x V_h : ...] @ U

        Can work as either self-attention or cross-attention (if kv is provided to .forward).

        Args:
            input_size: dimensionality of the input embeddings
            hidden: total hidden size across all heads; must be divisible by num_heads
            num_heads: number of attention heads
            causal: use causal masking (do not allow a query position to attend to later source positions)
        """
        if hidden % num_heads:
            raise ValueError(f"hidden should be divisible by num_heads, "
                             f"but got hidden={hidden} and num_heads={num_heads}")
        super().__init__()

        self.k = nn.Linear(input_size, hidden)
        self.q = nn.Linear(input_size, hidden)
        self.v = nn.Linear(input_size, hidden)
        self.mix = nn.Linear(hidden, hidden)

        self.num_heads = num_heads
        self.head_size = hidden // num_heads
        self.scale = self.head_size ** 0.5
        self.causal = causal  # causal masking
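
    # Note: each head operates on a head_size = hidden // num_heads slice of the hidden vector,
    # and dot-product scores are divided by scale = sqrt(head_size) to keep their magnitude
    # stable before the softmax (e.g., with hidden=512 and num_heads=8, head_size=64 and scale=8.0;
    # those numbers are only an illustration, not values used in this file).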
    def forward(self, q, kv=None, key_padding_mask=None, return_attention=False):
        """Computes [softmax(q Q_h @ (kv K_h)^T / scale) @ kv V_h : ...] @ U

        Performs self-attention if kv is not specified.
        In this case, kv = q and kv_seq_len = query_seq_len.

        Args:
            q: FloatTensor[batch_size, query_seq_len, input_size], the attending (source) sequence
            kv: optional FloatTensor[batch_size, kv_seq_len, input_size], the attended (target) sequence
            key_padding_mask: BoolTensor[batch_size, kv_seq_len], 0 means unpadded, 1 means padded
            return_attention: if True, also return the attention probabilities

        Returns:
            FloatTensor[batch_size, query_seq_len, hidden]
        """
        # Task 1.1 (1 point)
        # Update this function with the cross-attention mechanism.
        # If kv (the target) is None, the target (kv) and the source (q) are the same sequence.
        # Define k, q and v using self.k, self.q and self.v depending on whether kv is provided.
        # Note: k and v are computed before q is overwritten, so self-attention uses the original input.
        # YOUR CODE STARTS HERE (~ 2 lines of code)
        k = self.k(kv) if kv is not None else self.k(q)  # [batch_size, kv_seq_len, hidden]
        v = self.v(kv) if kv is not None else self.v(q)  # [batch_size, kv_seq_len, hidden]
        q = self.q(q)                                    # [batch_size, query_seq_len, hidden]
        # YOUR CODE ENDS HERE

        bs, attending_seq, _ = q.shape
        attended_seq = k.shape[1]

        # [b, s, h] -> [b, h, s] -> [b * heads, h / heads, s] -> [b * heads, s, h / heads]
        k = k.transpose(1, 2).reshape(bs * self.num_heads, self.head_size, -1).transpose(1, 2).contiguous()  # [batch * num_heads, kv_seq, hidden / num_heads]
        q = q.transpose(1, 2).reshape(bs * self.num_heads, self.head_size, -1).transpose(1, 2).contiguous()  # [batch * num_heads, query_seq, hidden / num_heads]
        v = v.transpose(1, 2).reshape(bs * self.num_heads, self.head_size, -1).transpose(1, 2).contiguous()  # [batch * num_heads, kv_seq, hidden / num_heads]

        scores = q @ k.transpose(1, 2) / self.scale  # [batch * num_heads, attending_seq, attended_seq]
        assert scores.shape == (bs * self.num_heads, attending_seq, attended_seq)

        if key_padding_mask is not None:
            # Task 1.2 (1 point)
            # Padding
            # Set the scores corresponding to padded key positions (key_padding_mask == 1) to -inf
            # so that they receive zero weight after the softmax.
            #
            # You might need to reshape the scores to [batch_size, num_heads, query_seq_len, kv_seq_len];
            # in this case, remember to reshape them back.
            # Our implementation is 3 lines
            # YOUR CODE STARTS HERE
            scores = scores.view(bs, self.num_heads, attending_seq, attended_seq)   # [batch, num_heads, attending_seq, attended_seq]
            scores = scores.masked_fill(key_padding_mask[:, None, None, :].bool(), float("-inf"))
            scores = scores.view(bs * self.num_heads, attending_seq, attended_seq)  # [batch * num_heads, attending_seq, attended_seq]
            # YOUR CODE ENDS HERE

        assert scores.size() == (bs * self.num_heads, attending_seq, attended_seq), \
            f"scores have wrong shape. Expected {(bs * self.num_heads, attending_seq, attended_seq)}, got {scores.size()}"

        if self.causal:
            causal_mask = torch.triu(torch.ones(attending_seq, attended_seq, dtype=torch.bool, device=scores.device), diagonal=1)
            scores.masked_fill_(causal_mask.unsqueeze(0), float("-inf"))

        probs = torch.softmax(scores, dim=-1)  # [batch * num_heads, attending_seq, attended_seq]
        att = probs @ v                        # [batch * num_heads, attending_seq, hidden / num_heads]

        # [b * heads, s, h / heads] -> [b * heads, h / heads, s] -> [b, h, s] -> [b, s, h]
        att = att.transpose(1, 2).reshape(bs, -1, attending_seq).transpose(1, 2).contiguous()
        att = self.mix(att)

        if return_attention:
            return att, probs
        return att
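

# Illustrative usage sketch (assumes small toy dimensions and random inputs; the sizes and
# tensors below are made up for demonstration only). It shows how MultiHeadAttention is called
# for self-attention, cross-attention with a key padding mask, and causal masking.
if __name__ == "__main__":
    torch.manual_seed(0)

    batch_size, query_len, kv_len, input_size, hidden, num_heads = 2, 5, 7, 16, 32, 4
    attn = MultiHeadAttention(input_size, hidden, num_heads)

    x = torch.randn(batch_size, query_len, input_size)    # queries (and keys/values for self-attention)
    memory = torch.randn(batch_size, kv_len, input_size)  # keys/values for cross-attention

    # Self-attention: kv defaults to q
    out = attn(x)
    print(out.shape)  # torch.Size([2, 5, 32]) == [batch_size, query_len, hidden]

    # Cross-attention with a key padding mask (1 / True marks padded key positions)
    key_padding_mask = torch.zeros(batch_size, kv_len, dtype=torch.bool)
    key_padding_mask[:, -2:] = True  # pretend the last two key positions are padding
    out, probs = attn(x, kv=memory, key_padding_mask=key_padding_mask, return_attention=True)
    print(out.shape)    # torch.Size([2, 5, 32])
    print(probs.shape)  # torch.Size([8, 5, 7]) == [batch_size * num_heads, query_len, kv_len]

    # Causal self-attention: position i cannot attend to positions j > i
    causal_attn = MultiHeadAttention(input_size, hidden, num_heads, causal=True)
    out = causal_attn(x)
    print(out.shape)  # torch.Size([2, 5, 32])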