Update wav2vec2speechclassification.py
Browse files
wav2vec2speechclassification.py
CHANGED
|
@@ -2,7 +2,7 @@ from dataclasses import dataclass
|
|
| 2 |
from typing import Optional, Tuple
|
| 3 |
import torch
|
| 4 |
from transformers.file_utils import ModelOutput
|
| 5 |
-
from
|
| 6 |
|
| 7 |
|
| 8 |
@dataclass
|
|
@@ -25,8 +25,7 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import (
|
|
| 25 |
|
| 26 |
class Wav2Vec2ClassificationHead(nn.Module):
|
| 27 |
"""Head for wav2vec classification task."""
|
| 28 |
-
config_class =
|
| 29 |
-
model_type = "wav2vec2"
|
| 30 |
|
| 31 |
def __init__(self, config):
|
| 32 |
super().__init__()
|
|
@@ -45,8 +44,7 @@ class Wav2Vec2ClassificationHead(nn.Module):
|
|
| 45 |
|
| 46 |
|
| 47 |
class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
|
| 48 |
-
config_class =
|
| 49 |
-
model_type = "wav2vec2"
|
| 50 |
|
| 51 |
def __init__(self, config):
|
| 52 |
super().__init__(config)
|
|
|
|
| 2 |
from typing import Optional, Tuple
|
| 3 |
import torch
|
| 4 |
from transformers.file_utils import ModelOutput
|
| 5 |
+
from .wav2vec2fsr_config import W2V2FSRConfig
|
| 6 |
|
| 7 |
|
| 8 |
@dataclass
|
|
|
|
| 25 |
|
| 26 |
class Wav2Vec2ClassificationHead(nn.Module):
|
| 27 |
"""Head for wav2vec classification task."""
|
| 28 |
+
config_class = W2V2FSRConfig
|
|
|
|
| 29 |
|
| 30 |
def __init__(self, config):
|
| 31 |
super().__init__()
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
|
| 47 |
+
config_class = W2V2FSRConfig
|
|
|
|
| 48 |
|
| 49 |
def __init__(self, config):
|
| 50 |
super().__init__(config)
|