This commit is contained in:
yifanyeung 2023-12-23 15:27:34 +08:00
parent 77125064cb
commit 75c5389979
9 changed files with 42 additions and 6736 deletions

File diff suppressed because it is too large.

@@ -0,0 +1 @@
../../ASR/zipformer/beam_search.py

@@ -22,39 +22,39 @@
Usage:
(1) ctc-decoding
./zipformer/ctc_decode.py \
./hubert/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--exp-dir ./hubert/exp \
--use-ctc 1 \
--max-duration 600 \
--decoding-method ctc-decoding
(2) 1best
./zipformer/ctc_decode.py \
./hubert/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--exp-dir ./hubert/exp \
--use-ctc 1 \
--max-duration 600 \
--hlg-scale 0.6 \
--decoding-method 1best
(3) nbest
./zipformer/ctc_decode.py \
./hubert/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--exp-dir ./hubert/exp \
--use-ctc 1 \
--max-duration 600 \
--hlg-scale 0.6 \
--decoding-method nbest
(4) nbest-rescoring
./zipformer/ctc_decode.py \
./hubert/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--exp-dir ./hubert/exp \
--use-ctc 1 \
--max-duration 600 \
--hlg-scale 0.6 \
@@ -63,10 +63,10 @@ Usage:
--decoding-method nbest-rescoring
(5) whole-lattice-rescoring
./zipformer/ctc_decode.py \
./hubert/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--exp-dir ./hubert/exp \
--use-ctc 1 \
--max-duration 600 \
--hlg-scale 0.6 \
@@ -164,7 +164,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
default="hubert/exp",
help="The experiment dir",
)
@@ -340,7 +340,7 @@ def decode_one_batch(
feature_lens = supervisions["num_frames"].to(device)
if params.causal:
# this seems to cause insertions at the end of the utterance if used with zipformer.
# this seems to cause insertions at the end of the utterance if used with hubert.
pad_len = 30
feature_lens += pad_len
feature = torch.nn.functional.pad(
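The pad call itself is cut off by the diff view, so the exact arguments are not visible here. As a hedged sketch of the general pattern only (the fill value below is illustrative, not taken from the recipe), right-padding a (N, T, C) feature tensor by pad_len frames along the time axis looks like this:

# Hedged sketch only: the exact arguments of the recipe's pad call are not
# visible in this hunk. Right-pad (N, T, C) features by pad_len frames along T.
import torch

feature = torch.randn(4, 100, 80)       # (N, T, C) fbank features
feature_lens = torch.full((4,), 100)     # valid frames per utterance
pad_len = 30
feature_lens = feature_lens + pad_len
# F.pad pads the trailing dims first: (C_left, C_right, T_left, T_right)
feature = torch.nn.functional.pad(
    feature, pad=(0, 0, 0, pad_len), value=-23.0  # illustrative log-mel floor
)
assert feature.shape == (4, 130, 80)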

@@ -92,9 +92,9 @@ class HubertAsrDataset(torch.utils.data.Dataset):
feature_size=1,
sampling_rate=16000,
padding_side="right",
padding_value=0.0,
padding_value=0,
do_normalize=True,
return_attention_mask=True,
return_attention_mask=False,
)
def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
@@ -148,7 +148,7 @@ if __name__ == "__main__":
)
for batch_idx, batch in enumerate(dl):
import pdb
pdb.set_trace()
pass
print(batch["audio"])
print(batch["audio_lens"])
print(batch["supervisions"]["text"])
print(batch["cuts"])

@@ -1,134 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import Balancer
class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""
def __init__(
self,
vocab_size: int,
decoder_dim: int,
blank_id: int,
context_size: int,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
decoder_dim:
Dimension of the input embedding, and of the decoder output.
blank_id:
The ID of the blank symbol.
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=decoder_dim,
)
# the balancers are to avoid any drift in the magnitude of the
# embeddings, which would interact badly with parameter averaging.
self.balancer = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
self.blank_id = blank_id
assert context_size >= 1, context_size
self.context_size = context_size
self.vocab_size = vocab_size
if context_size > 1:
self.conv = nn.Conv1d(
in_channels=decoder_dim,
out_channels=decoder_dim,
kernel_size=context_size,
padding=0,
groups=decoder_dim // 4, # group size == 4
bias=False,
)
self.balancer2 = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
self.balancer2 = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
need_pad:
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
y = y.to(torch.int64)
# The clamp() below is a temporary fix for a mismatch at utterance start;
# beam_search.py uses negative ids there.
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
embedding_out = self.balancer(embedding_out)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = F.relu(embedding_out)
embedding_out = self.balancer2(embedding_out)
return embedding_out
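The deleted decoder.py implements the stateless prediction network described in its docstring: an embedding of the last context_size tokens followed by a grouped Conv1d, with no recurrence. A minimal sketch of the same idea in plain PyTorch, omitting the icefall-specific Balancer modules; the class and variable names below are illustrative, not part of the recipe.

# Minimal sketch of the stateless-decoder idea: embed the previous
# context_size tokens and mix them with a grouped Conv1d instead of an RNN.
# Balancer modules from icefall's scaling.py are intentionally omitted.
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyStatelessDecoder(nn.Module):
    def __init__(self, vocab_size: int, decoder_dim: int, context_size: int):
        super().__init__()
        self.context_size = context_size
        self.embedding = nn.Embedding(vocab_size, decoder_dim)
        self.conv = nn.Conv1d(
            decoder_dim, decoder_dim, kernel_size=context_size,
            groups=decoder_dim // 4, bias=False,
        )

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) token ids; returns (N, U, decoder_dim)
        emb = self.embedding(y).permute(0, 2, 1)      # (N, D, U)
        emb = F.pad(emb, (self.context_size - 1, 0))  # left-pad in time
        return F.relu(self.conv(emb)).permute(0, 2, 1)


dec = TinyStatelessDecoder(vocab_size=500, decoder_dim=512, context_size=2)
out = dec(torch.randint(0, 500, (3, 7)))
assert out.shape == (3, 7, 512)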

@@ -0,0 +1 @@
../../ASR/zipformer/decoder.py

@@ -64,7 +64,6 @@ from lhotse.utils import fix_random_seed
from model import AsrModel
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -152,7 +151,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--do-stable-layer-norm",
type=str2bool,
default=True,
default=False,
)
parser.add_argument(
"--feat-extract-activation",
@@ -162,12 +161,12 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--feat-extract-norm",
type=str,
default="layer",
default="group",
)
parser.add_argument(
"--feat-proj-dropout",
type=float,
default=0.0,
default=0.1,
)
parser.add_argument(
"--feat-proj-layer-norm",
@@ -192,7 +191,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--hidden-size",
type=int,
default=1024,
default=768,
)
parser.add_argument(
"--initializer-range",
@@ -202,7 +201,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--intermediate-size",
type=int,
default=4096,
default=3072,
)
parser.add_argument(
"--layer-norm-eps",
@@ -247,7 +246,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-attention-heads",
type=int,
default=16,
default=12,
)
parser.add_argument(
"--num-conv-pos-embedding-groups",
@@ -262,14 +261,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-hidden-layers",
type=int,
default=24,
)
parser.add_argument(
"--encoder-dim",
type=int,
default=1024,
help="Embedding dimension in encoder model.",
default=12,
)
parser.add_argument(
@@ -366,6 +358,14 @@ def get_parser():
""",
)
parser.add_argument(
"--pretrained-dir",
type=str,
default="download/hubert-base-ls960",
help="""The pretrained model dir.
It specifies the directory where the pretrained checkpoint is saved.""",
)
parser.add_argument(
"--bpe-model",
type=str,
@@ -657,7 +657,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
encoder_dim=params.encoder_dim,
encoder_dim=params.hidden_size,
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
@@ -685,7 +685,7 @@ def get_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=params.encoder_dim,
encoder_dim=params.hidden_size,
decoder_dim=params.decoder_dim,
vocab_size=params.vocab_size,
use_transducer=params.use_transducer,
@@ -731,6 +731,8 @@ def load_checkpoint_if_available(
elif params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
logging.info(f"Loading {params.pretrained_dir}")
model.encoder = HubertModel.from_pretrained(params.pretrained_dir)
return None
assert filename.is_file(), f"{filename} does not exist!"
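The new argument defaults (hidden_size 768, intermediate_size 3072, 12 attention heads, 12 hidden layers, feat_extract_norm "group", do_stable_layer_norm False, feat_proj_dropout 0.1) match the HuBERT base architecture, and load_checkpoint_if_available now falls back to HubertModel.from_pretrained when no checkpoint exists yet. A minimal sketch, assuming the Hugging Face transformers HuBERT implementation; facebook/hubert-base-ls960 is used as an example of the checkpoint that the default download/hubert-base-ls960 directory presumably mirrors.

# Hedged sketch: load a HuBERT base encoder with HF transformers and confirm
# its config matches the new argument defaults in this diff. The hub name
# below is an example; the recipe points --pretrained-dir at a local copy.
from transformers import HubertModel

encoder = HubertModel.from_pretrained("facebook/hubert-base-ls960")
cfg = encoder.config
assert cfg.hidden_size == 768
assert cfg.intermediate_size == 3072
assert cfg.num_attention_heads == 12
assert cfg.num_hidden_layers == 12
assert cfg.feat_extract_norm == "group"
assert cfg.do_stable_layer_norm is False
# The joiner and AsrModel now use params.hidden_size (768) as encoder_dim.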

@@ -1,67 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from scaling import ScaledLinear
class Joiner(nn.Module):
def __init__(
self,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
):
super().__init__()
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
self.output_linear = nn.Linear(joiner_dim, vocab_size)
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
project_input: bool = True,
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
project_input:
If true, apply input projections encoder_proj and decoder_proj.
If this is false, it is the user's responsibility to do this
manually.
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim, (
encoder_out.shape,
decoder_out.shape,
)
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
else:
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
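The deleted joiner.py projects encoder and decoder outputs to a common joiner_dim, sums them, and applies tanh followed by an output linear layer. A minimal sketch of the same computation with plain nn.Linear in place of icefall's ScaledLinear, to make the (N, T, s_range, C) shape contract concrete; the names below are illustrative.

# Minimal sketch of the joiner computation with plain Linear layers
# (icefall uses ScaledLinear with initial_scale=0.25 for the projections).
import torch
import torch.nn as nn


class TinyJoiner(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, joiner_dim, vocab_size):
        super().__init__()
        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(self, encoder_out, decoder_out):
        # encoder_out: (N, T, s_range, encoder_dim)
        # decoder_out: (N, T, s_range, decoder_dim)
        logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        return self.output_linear(torch.tanh(logit))  # (N, T, s_range, vocab)


joiner = TinyJoiner(encoder_dim=768, decoder_dim=512, joiner_dim=512, vocab_size=500)
logits = joiner(torch.randn(2, 50, 5, 768), torch.randn(2, 50, 5, 512))
assert logits.shape == (2, 50, 5, 500)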

@@ -0,0 +1 @@
../../ASR/zipformer/joiner.py

File diff suppressed because it is too large.

@@ -0,0 +1 @@
../../ASR/zipformer/optim.py

File diff suppressed because it is too large.

@@ -0,0 +1 @@
../../ASR/zipformer/scaling.py

@@ -1,406 +0,0 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Tuple
import torch
from scaling import (
Balancer,
BiasNorm,
Dropout3,
FloatLike,
Optional,
ScaledConv2d,
ScaleGrad,
ScheduledFloat,
SwooshL,
SwooshR,
Whiten,
)
from torch import Tensor, nn
class ConvNeXt(nn.Module):
"""
Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
"""
def __init__(
self,
channels: int,
hidden_ratio: int = 3,
kernel_size: Tuple[int, int] = (7, 7),
layerdrop_rate: FloatLike = None,
):
super().__init__()
self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
hidden_channels = channels * hidden_ratio
if layerdrop_rate is None:
layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
self.layerdrop_rate = layerdrop_rate
self.depthwise_conv = nn.Conv2d(
in_channels=channels,
out_channels=channels,
groups=channels,
kernel_size=kernel_size,
padding=self.padding,
)
self.pointwise_conv1 = nn.Conv2d(
in_channels=channels, out_channels=hidden_channels, kernel_size=1
)
self.hidden_balancer = Balancer(
hidden_channels,
channel_dim=1,
min_positive=0.3,
max_positive=1.0,
min_abs=0.75,
max_abs=5.0,
)
self.activation = SwooshL()
self.pointwise_conv2 = ScaledConv2d(
in_channels=hidden_channels,
out_channels=channels,
kernel_size=1,
initial_scale=0.01,
)
self.out_balancer = Balancer(
channels,
channel_dim=1,
min_positive=0.4,
max_positive=0.6,
min_abs=1.0,
max_abs=6.0,
)
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=5.0,
prob=(0.025, 0.25),
grad_scale=0.01,
)
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return self.forward_internal(x)
layerdrop_rate = float(self.layerdrop_rate)
if layerdrop_rate != 0.0:
batch_size = x.shape[0]
mask = (
torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
> layerdrop_rate
)
else:
mask = None
# turns out this caching idea does not work with --world-size > 1
# return caching_eval(self.forward_internal, x, mask)
return self.forward_internal(x, mask)
def forward_internal(
self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
) -> Tensor:
"""
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
The returned value has the same shape as x.
"""
bypass = x
x = self.depthwise_conv(x)
x = self.pointwise_conv1(x)
x = self.hidden_balancer(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
if layer_skip_mask is not None:
x = x * layer_skip_mask
x = bypass + x
x = self.out_balancer(x)
if x.requires_grad:
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
x = self.out_whiten(x)
x = x.transpose(1, 3) # (N, C, H, W)
return x
def streaming_forward(
self,
x: Tensor,
cached_left_pad: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)
Returns:
- The returned value has the same shape as x.
- Updated cached_left_pad.
"""
padding = self.padding
# The length without right padding for depth-wise conv
T = x.size(2) - padding[0]
bypass = x[:, :, :T, :]
# Pad left side
assert cached_left_pad.size(2) == padding[0], (
cached_left_pad.size(2),
padding[0],
)
x = torch.cat([cached_left_pad, x], dim=2)
# Update cached left padding
cached_left_pad = x[:, :, T : padding[0] + T, :]
# depthwise_conv
x = torch.nn.functional.conv2d(
x,
weight=self.depthwise_conv.weight,
bias=self.depthwise_conv.bias,
padding=(0, padding[1]),
groups=self.depthwise_conv.groups,
)
x = self.pointwise_conv1(x)
x = self.hidden_balancer(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
x = bypass + x
return x, cached_left_pad
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/2 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = (T-3)//2 - 2 == (T-7)//2
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(
self,
in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
dropout: FloatLike = 0.1,
) -> None:
"""
Args:
in_channels:
Number of channels in. The input shape is (N, T, in_channels).
Caution: it requires T >= 7 and in_channels >= 7.
out_channels:
Output dim. The output shape is (N, (T-7)//2, out_channels)
layer1_channels:
Number of channels in layer1
layer2_channels:
Number of channels in layer2
layer3_channels:
Number of channels in layer3
dropout:
Dropout rate applied to the subsampled output
"""
assert in_channels >= 7
super().__init__()
# The ScaleGrad module is there to prevent the gradients
# w.r.t. the weight or bias of the first Conv2d module in self.conv from
# exceeding the range of fp16 when using automatic mixed precision (amp)
# training. (The second one is necessary to stop its bias from getting
# a too-large gradient).
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=layer1_channels,
kernel_size=3,
padding=(0, 1), # (time, freq)
),
ScaleGrad(0.2),
Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
SwooshR(),
nn.Conv2d(
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
padding=0,
),
Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
SwooshR(),
nn.Conv2d(
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=(1, 2), # (time, freq)
),
Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
SwooshR(),
)
# just one convnext layer
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
# (in_channels-3)//4
self.out_width = (((in_channels - 1) // 2) - 1) // 2
self.layer3_channels = layer3_channels
self.out = nn.Linear(self.out_width * layer3_channels, out_channels)
# use a larger than normal grad_scale on this whitening module; there is
# only one such module, so there is not a concern about adding together
# many copies of this extra gradient term.
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
prob=(0.025, 0.25),
grad_scale=0.02,
)
# max_log_eps=0.0 is to prevent both eps and the output of self.out from
# getting large; otherwise there is an unnecessary degree of freedom.
self.out_norm = BiasNorm(out_channels)
self.dropout = Dropout3(dropout, shared_dim=1)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in x.
Returns:
- a tensor of shape (N, (T-7)//2, odim)
- output lengths, of shape (batch_size,)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
# training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
# gradients.
x = self.conv(x)
x = self.convnext(x)
# Now x is of shape (N, odim, (T-7)//2, (idim-3)//4)
b, c, t, f = x.size()
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, (T-7)//2, out_width * layer3_channels))
x = self.out(x)
# Now x is of shape (N, (T-7)//2, odim)
x = self.out_whiten(x)
x = self.out_norm(x)
x = self.dropout(x)
if torch.jit.is_scripting() or torch.jit.is_tracing():
x_lens = (x_lens - 7) // 2
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
x_lens = (x_lens - 7) // 2
assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())
return x, x_lens
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
cached_left_pad: Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in x.
Returns:
- a tensor of shape (N, (T-7)//2, odim)
- output lengths, of shape (batch_size,)
- updated cache
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# T' = (T-7)//2
x = self.conv(x)
# T' = (T-7)//2-3
x, cached_left_pad = self.convnext.streaming_forward(
x, cached_left_pad=cached_left_pad
)
# Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, T', out_width * layer3_channels))
x = self.out(x)
# Now x is of shape (N, T', odim)
x = self.out_norm(x)
if torch.jit.is_scripting() or torch.jit.is_tracing():
assert self.convnext.padding[0] == 3
# The ConvNeXt module needs 3 frames of right padding after subsampling
x_lens = (x_lens - 7) // 2 - 3
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# The ConvNeXt module needs 3 frames of right padding after subsampling
assert self.convnext.padding[0] == 3
x_lens = (x_lens - 7) // 2 - 3
assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())
return x, x_lens, cached_left_pad
@torch.jit.export
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> Tensor:
"""Get initial states for Conv2dSubsampling module.
It is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
"""
left_pad = self.convnext.padding[0]
freq = self.out_width
channels = self.layer3_channels
cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
device
)
return cached_embed_left_pad
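The deleted Conv2dSubsampling claims an output length of T' = (T-3)//2 - 2 == (T-7)//2. A minimal sketch that rebuilds only the three convolutions with plain nn.Conv2d and checks that formula; the Balancer, SwooshR, and ConvNeXt pieces are dropped because they do not change the time dimension, and all names here are illustrative.

# Sanity-check the (T - 7) // 2 length claim using only the conv geometry:
# kernel 3 with time-padding 0, then kernel 3 stride 2, then kernel 3 stride (1, 2).
import torch
import torch.nn as nn

conv = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=(0, 1)),        # T -> T - 2
    nn.Conv2d(8, 32, kernel_size=3, stride=2, padding=0),   # halves T (and freq)
    nn.Conv2d(32, 128, kernel_size=3, stride=(1, 2)),       # T -> T - 2
)

for T in (20, 37, 100, 513):
    x = torch.randn(1, 1, T, 80)   # (N, C=1, T, idim=80)
    t_out = conv(x).shape[2]
    assert t_out == (T - 7) // 2 == (T - 3) // 2 - 2, (T, t_out)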