add monkey patch codes

This commit is contained in:
Yuekai Zhang 2024-01-25 14:03:51 +08:00
parent e1a55b945b
commit e49534f2dd

View File

@ -0,0 +1,29 @@
import torch
import torch.nn.functional as F
import whisper
def forward(self, x: torch.Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype)
for block in self.blocks:
x = block(x)
x = self.ln_post(x)
return x
def replace_whisper_encoder_forward():
"""
This function monkey patches the forward method of the whisper encoder.
To be called before the model is loaded, it changes whisper to process audio with any length < 30s.
"""
whisper.model.AudioEncoder.forward = forward