Spaces:
Running
Running
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from attention import SelfAttention | |
class CLIPEmbedding(nn.Module):
    """Token embedding plus a learned absolute position embedding.

    Args:
        n_vocab: Size of the token vocabulary.
        n_embed: Embedding dimension.
        n_token: Number of learned positions, i.e. the maximum sequence length.
    """

    def __init__(self, n_vocab, n_embed, n_token):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_embed)
        # Zero-initialised here; presumably replaced by loaded pretrained
        # weights (TODO confirm against the weight-loading code).
        self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embed)))

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed ``tokens`` and add the position embedding of each position.

        Args:
            tokens: Integer token ids whose last dimension is the sequence
                length (e.g. ``(batch, seq_len)``). ``seq_len`` may be anything
                from 1 up to ``n_token``; a longer sequence fails to broadcast
                against the position table and raises a ``RuntimeError``.

        Returns:
            Tensor of shape ``tokens.shape + (n_embed,)``.
        """
        seq_len = tokens.shape[-1]
        x = self.token_embedding(tokens)
        # Take only the first seq_len positions so shorter sequences work;
        # when seq_len == n_token this is the whole table, as before.
        return x + self.position_embedding[:seq_len]
class CLIPLayer(nn.Module):
    """One pre-LayerNorm transformer block.

    Two sub-layers, each normalised first and wrapped in a residual
    connection: causally-masked self-attention, then a position-wise MLP
    (``n_embed -> 4 * n_embed -> n_embed``) with a QuickGELU activation.
    """

    def __init__(self, n_head, n_embed):
        super().__init__()
        hidden_width = 4 * n_embed

        # Attention sub-layer.
        self.layernorm_1 = nn.LayerNorm(n_embed)
        self.attention = SelfAttention(n_head, n_embed)

        # Feed-forward sub-layer.
        self.layernorm_2 = nn.LayerNorm(n_embed)
        self.linear_1 = nn.Linear(n_embed, hidden_width)
        self.linear_2 = nn.Linear(hidden_width, n_embed)

    def forward(self, x):
        """Run the attention and MLP sub-layers, each with a skip connection.

        ``x`` is whatever ``SelfAttention`` accepts, presumably
        ``(batch, seq_len, n_embed)`` (TODO confirm). Each sub-layer's result
        is added back onto its input.
        """
        attended = self.attention(self.layernorm_1(x), causal_mask=True)
        x = attended + x

        hidden = self.linear_1(self.layernorm_2(x))
        # QuickGELU: sigmoid-based approximation of GELU.
        hidden = hidden * torch.sigmoid(1.702 * hidden)
        return self.linear_2(hidden) + x
class CLIP(nn.Module):
    """CLIP-style text encoder.

    The model is a token/position embedding, then a stack of ``CLIPLayer``
    blocks, then a final LayerNorm. The defaults reproduce the previously
    hard-coded configuration (vocab 49408, width 768, 77 positions, 12 heads,
    12 layers). Override them to build other sizes, e.g. a tiny model for
    tests.

    Args:
        n_vocab: Vocabulary size.
        n_embed: Hidden width.
        n_token: Maximum sequence length (number of learned positions).
        n_head: Attention heads per layer.
        n_layers: Number of ``CLIPLayer`` blocks.
    """

    def __init__(self, n_vocab=49408, n_embed=768, n_token=77, n_head=12, n_layers=12):
        super().__init__()
        self.embedding = CLIPEmbedding(n_vocab, n_embed, n_token)
        self.layers = nn.ModuleList([CLIPLayer(n_head, n_embed) for _ in range(n_layers)])
        self.layernorm = nn.LayerNorm(n_embed)

    def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        """Encode token ids into per-position hidden states.

        Args:
            tokens: Token ids; cast to ``torch.long`` before embedding.

        Returns:
            The output of the last layer after the final LayerNorm; its last
            dimension is ``n_embed``.
        """
        state = self.embedding(tokens.type(torch.long))
        for layer in self.layers:
            state = layer(state)
        return self.layernorm(state)
if __name__ == "__main__":
    # Smoke test: push one batch of random token ids through a freshly
    # initialised (untrained) encoder and report the output shape.
    batch_size, seq_len = 1, 77
    dummy_tokens = torch.randint(0, 49408, (batch_size, seq_len))

    model = CLIP()
    with torch.no_grad():  # inference only, so no autograd graph is needed
        output = model(dummy_tokens)

    # Expected: torch.Size([1, 77, 768])
    print("Output shape:", output.shape)