import torch
import torch.utils.checkpoint
from torch import nn

from transformers.models.whisper.configuration_whisper import WhisperConfig
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoderLayer,
    WhisperEncoder,
    WhisperModel,
    WhisperForConditionalGeneration,
)

from .configuration_lite_whisper import LiteWhisperConfig
class LinearLowRank(nn.Module):
    """Low-rank replacement for nn.Linear: the weight is stored as two thin factors."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        low_rank_features: int,
    ):
        super().__init__()

        # Two thin factors replace one (in_features, out_features) matrix,
        # cutting parameters from in*out to (in + out) * low_rank_features.
        self.weight1 = nn.Parameter(torch.randn(in_features, low_rank_features))
        self.weight2 = nn.Parameter(torch.randn(low_rank_features, out_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply the factors in sequence; equivalent to x @ (weight1 @ weight2) + bias.
        return (x @ self.weight1) @ self.weight2 + self.bias
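

# The torch.randn initialization above is only a placeholder: in practice the
# factors come from a compressed checkpoint. As an illustrative sketch (this
# helper name is ours, not part of the released module), a pretrained
# nn.Linear can be factored with a truncated SVD. Note that nn.Linear stores
# its weight as (out_features, in_features) and computes x @ weight.T, so it
# is the transpose that gets factored into weight1 @ weight2.
def low_rank_from_linear(linear: nn.Linear, rank: int) -> "LinearLowRank":
    low_rank = LinearLowRank(linear.in_features, linear.out_features, rank)
    # weight.T has shape (in_features, out_features); keeping the top `rank`
    # singular triplets makes weight1 @ weight2 its best rank-`rank` approximation.
    u, s, vh = torch.linalg.svd(linear.weight.data.T, full_matrices=False)
    with torch.no_grad():
        low_rank.weight1.copy_(u[:, :rank] * s[:rank])  # (in_features, rank)
        low_rank.weight2.copy_(vh[:rank, :])            # (rank, out_features)
        low_rank.bias.copy_(linear.bias.data)
    return low_rank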
class LiteWhisperEncoderLayer(WhisperEncoderLayer):
    """Whisper encoder layer whose projections are swapped for low-rank ones.

    `low_rank_config` maps a submodule name to its rank; names that are
    absent keep the dense nn.Linear created by the parent class.
    """

    def __init__(self, config: WhisperConfig, low_rank_config: dict[str, int]):
        super().__init__(config)

        # Self-attention projections (all embed_dim -> embed_dim).
        if "k_proj" in low_rank_config:
            self.self_attn.k_proj = LinearLowRank(self.embed_dim, self.embed_dim, low_rank_config["k_proj"])
        if "v_proj" in low_rank_config:
            self.self_attn.v_proj = LinearLowRank(self.embed_dim, self.embed_dim, low_rank_config["v_proj"])
        if "q_proj" in low_rank_config:
            self.self_attn.q_proj = LinearLowRank(self.embed_dim, self.embed_dim, low_rank_config["q_proj"])
        if "out_proj" in low_rank_config:
            self.self_attn.out_proj = LinearLowRank(self.embed_dim, self.embed_dim, low_rank_config["out_proj"])

        # Feed-forward projections (embed_dim <-> encoder_ffn_dim).
        if "fc1" in low_rank_config:
            self.fc1 = LinearLowRank(self.embed_dim, config.encoder_ffn_dim, low_rank_config["fc1"])
        if "fc2" in low_rank_config:
            self.fc2 = LinearLowRank(config.encoder_ffn_dim, self.embed_dim, low_rank_config["fc2"])
class LiteWhisperEncoder(WhisperEncoder):
    """Whisper encoder built from LiteWhisperEncoderLayer, one rank dict per layer."""

    def __init__(self, config: WhisperConfig, low_rank_config: list[dict[str, int]]):
        super().__init__(config)

        # Replace the dense layers created by the parent; low_rank_config
        # must provide one dict per encoder layer.
        self.layers = nn.ModuleList([
            LiteWhisperEncoderLayer(config, low_rank_config[i])
            for i in range(config.encoder_layers)
        ])
class LiteWhisperModel(WhisperModel):
    """WhisperModel with the encoder swapped for the low-rank variant; the decoder is untouched."""

    def __init__(self, config: WhisperConfig, low_rank_config: list[dict[str, int]]):
        super().__init__(config)
        self.encoder = LiteWhisperEncoder(config, low_rank_config)
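

# A quick way to see the compression, assuming the same config drives both
# models: the low-rank encoder should report noticeably fewer parameters than
# the dense one.
#
#     dense = WhisperModel(config)
#     lite = LiteWhisperModel(config, low_rank_config)
#     print(sum(p.numel() for p in dense.encoder.parameters()))
#     print(sum(p.numel() for p in lite.encoder.parameters()))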
class LiteWhisperForConditionalGeneration(WhisperForConditionalGeneration):
    config_class = LiteWhisperConfig

    def __init__(self, config: LiteWhisperConfig):
        # Read the per-layer ranks off the config before the parent builds
        # the default dense model, then swap in the low-rank model.
        low_rank_config = getattr(config, "low_rank_config", None)
        super().__init__(config)
        self.model = LiteWhisperModel(config, low_rank_config)
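

# Minimal end-to-end smoke test. This assumes LiteWhisperConfig subclasses
# WhisperConfig and tolerates having `low_rank_config` set on it (the config
# class lives in configuration_lite_whisper and is not shown here); released
# checkpoints would instead be loaded with
# LiteWhisperForConditionalGeneration.from_pretrained(...).
if __name__ == "__main__":
    config = LiteWhisperConfig()  # Whisper-tiny-sized defaults, assuming WhisperConfig's
    config.low_rank_config = [
        {"q_proj": 64, "k_proj": 64, "fc1": 128}
        for _ in range(config.encoder_layers)
    ]
    model = LiteWhisperForConditionalGeneration(config)

    # Whisper encoders expect exactly 2 * max_source_positions mel frames
    # (the two input convolutions downsample by a factor of two overall).
    mel = torch.randn(1, config.num_mel_bins, 2 * config.max_source_positions)
    ids = torch.tensor([[config.decoder_start_token_id]])
    out = model(input_features=mel, decoder_input_ids=ids)
    print(out.logits.shape)  # expected: (1, 1, vocab_size)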