from torch import Tensor
import torch


def pool(last_hidden_states: Tensor,
         attention_mask: Tensor,
         pool_type: str) -> Tensor:
    # Zero out hidden states at padding positions so they cannot affect pooling.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)

    if pool_type == "avg":
        # Mean over the non-padding tokens of each sequence.
        emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    elif pool_type == "weighted_avg":
        # Plain sum over the sequence dimension (padding positions are already zeroed).
        emb = last_hidden.sum(dim=1)
    elif pool_type == "cls":
        # Embedding of the first ([CLS]) token.
        emb = last_hidden[:, 0]
    elif pool_type == "last":
        # Embedding of the last non-padding token; handle left- vs right-padded batches.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            # With left padding, the final position of every sequence is a real token.
            emb = last_hidden[:, -1]
        else:
            # With right padding, index each sequence at its own last real token.
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden.shape[0]
            emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]
    else:
        raise ValueError(f"pool_type {pool_type} not supported")
    return emb
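Below is a minimal usage sketch, not part of the original function: it assumes a Hugging Face encoder loaded via AutoModel/AutoTokenizer (the checkpoint name "intfloat/e5-base-v2", the example sentences, and the L2-normalization step are illustrative choices). The tokenizer's attention_mask is passed straight through to pool, and the pooled vectors are normalized before computing cosine similarities.

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Illustrative checkpoint; any encoder exposing last_hidden_state works the same way.
tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-base-v2")
model = AutoModel.from_pretrained("intfloat/e5-base-v2")

texts = ["query: how do I pool token embeddings?",
         "passage: pooling collapses per-token states into a single vector."]
batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    outputs = model(**batch)

emb = pool(outputs.last_hidden_state, batch["attention_mask"], pool_type="avg")
emb = F.normalize(emb, p=2, dim=-1)   # unit-length embeddings
scores = emb @ emb.T                  # pairwise cosine similarities
print(scores.shape)                   # torch.Size([2, 2])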