import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Wav2Vec2Model


class Wav2VecIntent(nn.Module):
    def __init__(self, num_classes=31, pretrained_model="facebook/wav2vec2-large"):
        super().__init__()
        # Load the pretrained wav2vec 2.0 encoder
        self.wav2vec = Wav2Vec2Model.from_pretrained(pretrained_model)
        # Hidden size comes from the encoder config (1024 for the large model)
        hidden_size = self.wav2vec.config.hidden_size
        # Layer normalization over the encoder outputs
        self.layer_norm = nn.LayerNorm(hidden_size)
        # Scalar attention score per time step, used for pooling
        self.attention = nn.Linear(hidden_size, 1)
        # Dropout for regularization
        self.dropout = nn.Dropout(p=0.5)
        # Classification head
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, input_values, attention_mask=None):
        # Encode raw waveforms into frame-level features
        outputs = self.wav2vec(
            input_values,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_states = outputs.last_hidden_state  # [batch, sequence, hidden]
        # Normalize the frame features
        hidden_states = self.layer_norm(hidden_states)
        # Attention pooling: score each frame, then take a weighted sum.
        attn_scores = self.attention(hidden_states)  # [batch, sequence, 1]
        if attention_mask is not None:
            # The encoder downsamples the waveform, so map the sample-level
            # mask to frame level (via a transformers helper) and exclude
            # padded frames from the softmax.
            frame_mask = self.wav2vec._get_feature_vector_attention_mask(
                hidden_states.shape[1], attention_mask
            )
            attn_scores = attn_scores.masked_fill(
                ~frame_mask.bool().unsqueeze(-1), float("-inf")
            )
        attn_weights = F.softmax(attn_scores, dim=1)
        x = torch.sum(hidden_states * attn_weights, dim=1)  # [batch, hidden]
        # Regularize, then classify
        x = self.dropout(x)
        x = self.fc(x)
        return x
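

# ---------------------------------------------------------------------------
# Quick smoke test (a minimal sketch, not part of the original model code).
# Assumptions: 16 kHz mono audio, the "facebook/wav2vec2-large" feature
# extractor checkpoint, and random waveforms standing in for real clips.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np
    from transformers import Wav2Vec2FeatureExtractor

    extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large")
    model = Wav2VecIntent(num_classes=31)
    model.eval()

    # Two fake clips of different lengths; padding=True pads them to a common
    # length and return_attention_mask=True yields the mask forward() expects.
    clips = [
        np.random.randn(16000).astype(np.float32),
        np.random.randn(24000).astype(np.float32),
    ]
    batch = extractor(
        clips,
        sampling_rate=16000,
        padding=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(batch.input_values, attention_mask=batch.attention_mask)
    print(logits.shape)  # expected: torch.Size([2, 31])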