mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 10:44:19 +00:00
add monkey patch codes
This commit is contained in:
parent
e1a55b945b
commit
e49534f2dd
@ -0,0 +1,29 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import whisper
|
||||
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Encode a mel spectrogram that may be shorter than the full context.

    Written to be installed as ``whisper.model.AudioEncoder.forward`` (see
    ``replace_whisper_encoder_forward``). The positional embedding is sliced
    to the number of frames actually produced by the convolutional front
    end, so the input does not have to fill the whole positional table.

    Args:
        x: torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio

    Returns:
        The output of ``self.ln_post`` applied after all ``self.blocks``;
        laid out as (batch_size, frames, channels) because of the permute
        below.

    Raises:
        ValueError: if the convolutional front end yields more frames than
            ``self.positional_embedding`` has rows.
    """
    x = F.gelu(self.conv1(x))
    x = F.gelu(self.conv2(x))
    # (batch, channels, frames) -> (batch, frames, channels)
    x = x.permute(0, 2, 1)

    n_frames = x.shape[1]
    max_frames = self.positional_embedding.shape[0]
    if n_frames > max_frames:
        # Slicing past the end of the table would silently return fewer rows
        # and the addition below would fail with an opaque broadcasting
        # error, so report the real cause instead.
        raise ValueError(
            f"input yields {n_frames} frames after the convolutions, but the "
            f"positional embedding only covers {max_frames} frames"
        )

    # Use only the leading rows of the positional table that match the
    # (possibly shorter) input; cast back in case the sum promoted the dtype.
    x = (x + self.positional_embedding[:n_frames, :]).to(x.dtype)

    for block in self.blocks:
        x = block(x)

    x = self.ln_post(x)
    return x
|
||||
|
||||
|
||||
def replace_whisper_encoder_forward() -> None:
    """Monkey patch the forward method of the whisper encoder.

    Assigns this module's ``forward`` function to
    ``whisper.model.AudioEncoder.forward``. The replacement slices the
    positional embedding to the input length, which lets whisper process
    audio of any length shorter than 30s instead of only the full window.

    To be called before the model is loaded.
    """
    whisper.model.AudioEncoder.forward = forward
|
Loading…
x
Reference in New Issue
Block a user