Spaces:
Build error
Build error
| import torch.nn as nn | |
class Projections(nn.Module):
    """Project ``clip_embed``-dim features to ``phi_embed`` dims and refine them.

    The input is mapped by a single linear layer, layer-normalised over its
    last dimension, and then passed through ``num_projection_layers``
    residual MLP blocks (Linear -> GELU -> Linear, each ``phi_embed`` ->
    ``phi_embed``, with a skip connection around every block).

    NOTE(review): the argument names suggest a CLIP-to-Phi embedding
    connector; confirm the intended use against the caller.

    Args:
        clip_embed: Size of the last dimension of the input.
        phi_embed: Size of the last dimension of the output.
        num_projection_layers: Number of residual MLP blocks (default 6).
    """

    def __init__(self, clip_embed, phi_embed, num_projection_layers=6):
        super().__init__()
        # Keep registration order (output, norm, projection_layers) so that
        # parameter ordering and seeded weight initialisation are unchanged.
        self.output = nn.Linear(clip_embed, phi_embed)
        self.norm = nn.LayerNorm(phi_embed)
        blocks = nn.ModuleList()
        for _ in range(num_projection_layers):
            blocks.append(self._residual_mlp(phi_embed))
        self.projection_layers = blocks

    @staticmethod
    def _residual_mlp(width):
        """Build one ``width -> width`` MLP body; the skip is added in forward."""
        return nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, width),
        )

    def forward(self, x):
        """Project ``x`` and run it through every residual block.

        Args:
            x: Tensor whose last dimension is ``clip_embed``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``phi_embed``.
        """
        hidden = self.norm(self.output(x))
        for block in self.projection_layers:
            # Residual connection: add the block's output back onto its input.
            hidden = hidden + block(hidden)
        return hidden