Spaces:
Paused
Paused
Update lavila/models/prompt_tuning.py
Browse files
lavila/models/prompt_tuning.py
CHANGED
|
@@ -58,7 +58,7 @@ class PromptPoolLearner(nn.Module):
|
|
| 58 |
|
| 59 |
if istrain:
|
| 60 |
inv_freq = self.id_table.sum() / self.id_table.float()
|
| 61 |
-
weights =
|
| 62 |
idx = torch.multinomial(weights, k, replacement=False)
|
| 63 |
else:
|
| 64 |
idx = torch.argsort(similarity, dim=-1, descending=True)[:, :k]
|
|
@@ -74,8 +74,8 @@ class PromptPoolLearner(nn.Module):
|
|
| 74 |
out['prompts'] = prompts
|
| 75 |
sel_sim = similarity[torch.arange(BZ).view(-1, 1), idx]
|
| 76 |
sel_key = keys[idx.flatten(), ...].view(BZ, k, self.prompt_dim)
|
| 77 |
-
diff = F.mse_loss((sel_sim.unsqueeze(1) @ sel_key).squeeze(), query.detach(), reduction='sum') / BZ
|
| 78 |
-
ksim = torch.
|
| 79 |
out['ps_loss'] = diff + ksim
|
| 80 |
|
| 81 |
return out
|
|
|
|
| 58 |
|
| 59 |
if istrain:
|
| 60 |
inv_freq = self.id_table.sum() / self.id_table.float()
|
| 61 |
+
weights = (similarity + 1) / 2 * gamma + (1 - gamma) * torch.softmax(inv_freq, dim=-1)
|
| 62 |
idx = torch.multinomial(weights, k, replacement=False)
|
| 63 |
else:
|
| 64 |
idx = torch.argsort(similarity, dim=-1, descending=True)[:, :k]
|
|
|
|
| 74 |
out['prompts'] = prompts
|
| 75 |
sel_sim = similarity[torch.arange(BZ).view(-1, 1), idx]
|
| 76 |
sel_key = keys[idx.flatten(), ...].view(BZ, k, self.prompt_dim)
|
| 77 |
+
diff = F.mse_loss((sel_sim.unsqueeze(1) @ sel_key).squeeze(1), query.detach(), reduction='sum') / BZ
|
| 78 |
+
ksim = torch.sum(torch.abs(torch.matmul(keys, keys.t()) - torch.eye(self.size).to(keys.device))) / BZ
|
| 79 |
out['ps_loss'] = diff + ksim
|
| 80 |
|
| 81 |
return out
|