modified zipformer

Daniil 2024-10-17 00:09:54 +00:00
parent f84270c935
commit 6072407669
5 changed files with 2143 additions and 4093 deletions


@@ -20,115 +20,162 @@ import torch.nn.functional as F
from scaling import Balancer


class Decoder(torch.nn.Module):
    """
    This class modifies the stateless decoder from the following paper:
    RNN-transducer with stateless prediction network
    https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419

    It removes the recurrent connection from the decoder, i.e., the prediction network.
    Different from the above paper, it adds an extra Conv1d right after the embedding layer.
    """

    def __init__(
        self, vocab_size: int, decoder_dim: int, context_size: int, device: torch.device,
    ) -> None:
        """
        Decoder initialization.

        Parameters
        ----------
        vocab_size : int
            The number of tokens or modeling units, including blank.
        decoder_dim : int
            The dimension of the decoder embeddings and of the decoder output.
        context_size : int
            The number of previous words used to predict the next word.
            1 means bigram; 2 means trigram. n means (n+1)-gram.
        device : torch.device
            The device used to store the layer weights. Should be
            either torch.device("cpu") or torch.device("cuda").
        """
        super().__init__()

        self.embedding = torch.nn.Embedding(vocab_size, decoder_dim)

        if context_size < 1:
            raise ValueError(
                'RNN-T decoder context size should be an integer greater '
                f'than or equal to 1, but got {context_size}.',
            )
        self.context_size = context_size
        self.conv = torch.nn.Conv1d(
            decoder_dim,
            decoder_dim,
            context_size,
            groups=decoder_dim // 4,
            bias=False,
            device=device,
        )

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        """
        Does a forward pass of the stateless Decoder module. Returns an output decoder tensor.

        Parameters
        ----------
        y : torch.Tensor[torch.int32]
            The input integer tensor of shape (N, context_size).
            The module input that corresponds to the last context_size decoded token indexes.

        Returns
        -------
        torch.Tensor[torch.float32]
            An output float tensor of shape (N, 1, decoder_dim).
        """
        # The clamp() call is a fix for a mismatch at utterance start:
        # negative token ids are used in RNN-T decoding, so they are clamped
        # for the embedding lookup and then masked out.
        embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(2)

        if self.context_size > 1:
            embedding_out = embedding_out.permute(0, 2, 1)
            embedding_out = self.conv(embedding_out)
            embedding_out = embedding_out.permute(0, 2, 1)

        embedding_out = torch.nn.functional.relu(embedding_out)

        return embedding_out
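
A standalone sketch of the clamp()/mask trick in forward() above, using a made-up toy vocabulary of 5 tokens and a context size of 2; negative ids mark "no previous token" at utterance start and must contribute nothing to the output:

import torch

embedding = torch.nn.Embedding(5, 4)  # toy vocab of 5 tokens, embedding dim 4

# Two hypotheses; -1 marks a missing previous token at utterance start.
y = torch.tensor([[-1, 3], [2, 4]])

# clamp(min=0) makes the lookup index valid, and the (y >= 0) mask zeroes
# the embedding of every negative (missing) position.
out = embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(2)

print(out.shape)                     # torch.Size([2, 2, 4])
print(out[0, 0].abs().sum().item())  # 0.0 -- the -1 position is fully masked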
class DecoderModule(torch.nn.Module):
    """
    A helper module to combine the decoder, the decoder projection,
    and the joiner inference together.
    """

    def __init__(
        self,
        vocab_size: int,
        decoder_dim: int,
        joiner_dim: int,
        context_size: int,
        beam: int,
        device: torch.device,
    ) -> None:
        """
        DecoderModule initialization.

        Parameters
        ----------
        vocab_size : int
            The number of tokens or modeling units, including blank.
        decoder_dim : int
            The dimension of the decoder embeddings and of the decoder output.
        joiner_dim : int
            The input joiner dimension.
        context_size : int
            The number of previous words used to predict the next word.
            1 means bigram; 2 means trigram. n means (n+1)-gram.
        beam : int
            The decoder beam size.
        device : torch.device
            The device used to store the layer weights. Should be
            either torch.device("cpu") or torch.device("cuda").
        """
        super().__init__()

        self.decoder = Decoder(vocab_size, decoder_dim, context_size, device)
        self.decoder_proj = torch.nn.Linear(decoder_dim, joiner_dim, device=device)
        self.joiner = Joiner(joiner_dim, vocab_size, device)
        self.vocab_size = vocab_size
        self.beam = beam

    def forward(
        self, decoder_input: torch.Tensor, encoder_out: torch.Tensor, hyps_log_prob: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Does a forward pass of the decoder, the decoder projection, and the joiner,
        and prunes the resulting hypotheses down to the beam size.

        Parameters
        ----------
        decoder_input : torch.Tensor[torch.int32]
            The input integer tensor of shape (num_hyps, context_size).
            The module input that corresponds to the last context_size decoded token indexes.
        encoder_out : torch.Tensor[torch.float32]
            An output tensor from the encoder after projection of shape (num_hyps, joiner_dim).
        hyps_log_prob : torch.Tensor[torch.float32]
            Hypothesis probabilities in a logarithmic scale of shape (num_hyps, 1).

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            A tuple of four tensors, each of shape (beam,):
            - the log probabilities of the beam best extended hypotheses,
            - the probabilities of the corresponding extension tokens,
            - the indexes of the hypotheses that are extended,
            - the indexes of the extension tokens.
        """
        decoder_out = self.decoder(decoder_input)
        decoder_out = self.decoder_proj(decoder_out)

        logits = self.joiner(encoder_out, decoder_out[:, 0, :])

        tokens_log_prob = torch.log_softmax(logits, dim=1)
        log_probs = (tokens_log_prob + hyps_log_prob).reshape(-1)

        hyps_topk_log_prob, topk_indexes = log_probs.topk(self.beam)
        topk_hyp_indexes = torch.floor_divide(topk_indexes, self.vocab_size).to(torch.int32)
        topk_token_indexes = torch.remainder(topk_indexes, self.vocab_size).to(torch.int32)
        tokens_topk_prob = torch.exp(tokens_log_prob.reshape(-1)[topk_indexes])

        return hyps_topk_log_prob, tokens_topk_prob, topk_hyp_indexes, topk_token_indexes
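
An illustrative, standalone sketch of the joint pruning arithmetic above, with a made-up beam of 2 and vocabulary of 3: the per-hypothesis token scores are flattened, the top candidates are picked over all (hypothesis, token) pairs, and floor_divide/remainder recover which hypothesis and which token each candidate corresponds to.

import torch

vocab_size, beam = 3, 2

# Running log probabilities of two live hypotheses, shape (num_hyps, 1).
hyps_log_prob = torch.tensor([[-0.5], [-1.0]])

# Token log probabilities produced by the joiner for each hypothesis,
# shape (num_hyps, vocab_size).
tokens_log_prob = torch.tensor([[-0.1, -2.0, -3.0],
                                [-0.2, -0.3, -4.0]])

# Joint scores over all (hypothesis, token) pairs, flattened to
# shape (num_hyps * vocab_size,).
log_probs = (tokens_log_prob + hyps_log_prob).reshape(-1)

hyps_topk_log_prob, topk_indexes = log_probs.topk(beam)
topk_hyp_indexes = torch.floor_divide(topk_indexes, vocab_size)
topk_token_indexes = torch.remainder(topk_indexes, vocab_size)

print(hyps_topk_log_prob)  # tensor([-0.6000, -1.2000])
print(topk_hyp_indexes)    # tensor([0, 1])  -> which hypothesis each candidate extends
print(topk_token_indexes)  # tensor([0, 0])  -> which token performs the extension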


@@ -19,49 +19,42 @@ import torch.nn as nn
from scaling import ScaledLinear


class Joiner(torch.nn.Module):

    def __init__(self, joiner_dim: int, vocab_size: int, device: torch.device) -> None:
        """
        Joiner initialization.

        Parameters
        ----------
        joiner_dim : int
            The input joiner dimension.
        vocab_size : int
            The output joiner dimension, the vocabulary size, the number of BPEs of the model.
        device : torch.device
            The device used to store the layer weights. Should be
            either torch.device("cpu") or torch.device("cuda").
        """
        super().__init__()

        self.output_linear = torch.nn.Linear(joiner_dim, vocab_size, device=device)

    def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor) -> torch.Tensor:
        """
        Does a forward pass of the Joiner module. Returns an output tensor after a simple joining.

        Parameters
        ----------
        encoder_out : torch.Tensor[torch.float32]
            An output tensor from the encoder after projection of shape (N, joiner_dim).
        decoder_out : torch.Tensor[torch.float32]
            An output tensor from the decoder after projection of shape (N, joiner_dim).

        Returns
        -------
        torch.Tensor[torch.float32]
            A float output tensor of token logits of shape (N, vocab_size).
        """
        return self.output_linear(torch.tanh(encoder_out + decoder_out))
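
For illustration, a standalone shape sketch of this simplified joiner with made-up sizes (joiner_dim=4, vocab_size=6); the encoder and decoder projections are assumed to have been applied upstream, and the log-softmax is left to the caller, as in DecoderModule.forward above.

import torch

num_hyps, joiner_dim, vocab_size = 2, 4, 6

output_linear = torch.nn.Linear(joiner_dim, vocab_size)

# Already-projected encoder and decoder outputs, one row per hypothesis.
encoder_out = torch.randn(num_hyps, joiner_dim)
decoder_out = torch.randn(num_hyps, joiner_dim)

# Additive joining, tanh non-linearity, and the output projection to token logits.
logits = output_linear(torch.tanh(encoder_out + decoder_out))
print(logits.shape)  # torch.Size([2, 6])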

File diff suppressed because it is too large.


@@ -36,371 +36,259 @@ from scaling import (
from torch import Tensor, nn


class ConvNeXt(torch.nn.Module):
    """
    The simplified ConvNeXt module interpretation based on https://arxiv.org/pdf/2206.14747.pdf.
    """

    def __init__(self, num_channels: int, device: torch.device) -> None:
        """
        ConvNeXt initialization.

        Parameters
        ----------
        num_channels : int
            The number of input and output channels for the ConvNeXt module.
        device : torch.device
            The device used to store the layer weights.
            Either torch.device("cpu") or torch.device("cuda").
        """
        super().__init__()

        self.padding = 3
        hidden_channels = num_channels * 3

        self.depthwise_conv = torch.nn.Conv2d(
            num_channels,
            num_channels,
            7,
            groups=num_channels,
            padding=(0, self.padding),  # (time, freq)
            device=device,
        )
        self.activation = SwooshL()
        self.pointwise_conv1 = torch.nn.Conv2d(num_channels, hidden_channels, 1, device=device)
        self.pointwise_conv2 = torch.nn.Conv2d(hidden_channels, num_channels, 1, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Does a forward pass of the ConvNeXt module.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            An input float tensor of shape (1, num_channels, num_input_frames, num_freqs).

        Returns
        -------
        torch.Tensor[torch.float32]
            An output float tensor of shape (1, num_channels, num_output_frames, num_freqs),
            where num_output_frames = num_input_frames - 6, since the depthwise convolution
            has no padding along the time axis.
        """
        bypass = x[:, :, self.padding: x.size(2) - self.padding]

        x = self.depthwise_conv(x)
        x = self.pointwise_conv1(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)

        x = bypass + x

        return x
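
A standalone sketch (random input, an assumed num_channels of 8) of why the bypass branch is cropped by self.padding frames on each side: the depthwise convolution keeps the frequency axis (kernel 7, padding 3) but shortens the time axis by 6 frames, so the crop keeps the residual addition shape-consistent.

import torch

num_channels, num_frames, num_freqs = 8, 16, 10
padding = 3

x = torch.randn(1, num_channels, num_frames, num_freqs)
depthwise_conv = torch.nn.Conv2d(
    num_channels, num_channels, 7, groups=num_channels, padding=(0, padding),
)

bypass = x[:, :, padding: x.size(2) - padding]
out = depthwise_conv(x)

print(bypass.shape, out.shape)  # both torch.Size([1, 8, 10, 10])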
class Conv2dSubsampling(torch.nn.Module):
    """
    Convolutional 2D subsampling module. It performs the prior subsampling
    (four times subsampling along the frequency axis and two times along the time axis)
    and low-level descriptor feature extraction from the log mel feature input
    before passing it to the zipformer encoder.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        layer1_channels: int,
        layer2_channels: int,
        layer3_channels: int,
        right_context: int,
        device: torch.device,
    ) -> None:
        """
        Conv2dSubsampling initialization.

        Parameters
        ----------
        input_dim : int
            The number of input channels. Corresponds to the
            number of features in the input feature tensor.
        output_dim : int
            The number of output channels.
        layer1_channels : int
            The number of output channels in the first Conv2d layer.
        layer2_channels : int
            The number of output channels in the second Conv2d layer.
        layer3_channels : int
            The number of output channels in the third Conv2d layer.
        right_context : int
            The look-ahead right context that is used to update the left cache.
        device : torch.device
            The device used to store the layer weights. Should be
            either torch.device("cpu") or torch.device("cuda").
        """
        super().__init__()

        if input_dim < 7:
            raise ValueError(
                'The input feature dimension of the Conv2dSubsampling layer can not be less '
                'than seven, otherwise the frequency subsampling will result in an empty '
                f'output. Expected input_dim to be at least 7 but got {input_dim}.',
            )

        self.right_context = right_context
        # Assume the batch size is 1 and the right padding is 10,
        # see the forward method on why the right padding is 10.
        self.right_pad = torch.full(
            (1, 10, input_dim), ZERO_LOG_MEL, dtype=torch.float32, device=device,
        )

        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=1,
                out_channels=layer1_channels,
                kernel_size=3,
                padding=(0, 1),  # (time, freq)
                device=device,
            ),
            SwooshR(),
            torch.nn.Conv2d(layer1_channels, layer2_channels, 3, stride=2, device=device),
            SwooshR(),
            torch.nn.Conv2d(layer2_channels, layer3_channels, 3, stride=(1, 2), device=device),
            SwooshR(),
        )
        self.convnext = ConvNeXt(layer3_channels, device=device)

        out_width = (((input_dim - 1) // 2) - 1) // 2
        self.out = torch.nn.Linear(out_width * layer3_channels, output_dim, device=device)
        self.out_norm = BiasNorm(output_dim, device=device)

    def forward(
        self, x: torch.Tensor, cached_left_pad: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Does a forward pass of the Conv2dSubsampling module.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            An input feature float tensor of shape (1, num_frames, input_dim).
        cached_left_pad : torch.Tensor[torch.float32]
            A left cache float tensor of shape (1, 10, input_dim). The left cache is required
            to preserve the "same" left padding at the output of the Conv2dSubsampling module.
            See the get_init_states() documentation to understand why we need exactly ten frames
            of left padding for the Conv2dSubsampling module.

        Returns
        -------
        tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]]
            A tuple of two float tensors:
            - the processing output of the Conv2dSubsampling module
              of shape (1, subsampled_num_frames, output_dim),
            - the updated left cache tensor of shape (1, 10, input_dim).
        """
        x = torch.cat((cached_left_pad, x), dim=1)
        new_cached_left_pad = x[
            :,
            x.size(1) - self.right_context - cached_left_pad.size(1):
            x.size(1) - self.right_context,
        ]

        # Now that the left cache is concatenated with the input, we need to right-pad the
        # input in a way that preserves the "same" type of padding, so that the output of the
        # module has the same duration as the input (taking the two times subsampling into
        # account). There are two possible outcomes depending on whether the number of input
        # frames is even or odd, but both scenarios are covered by 10 frames of right padding.
        # (Frame-alignment diagrams for the even and odd cases: the input plus right padding,
        # the outputs of the three Conv2d layers from self.conv, and the output of
        # self.convnext.depthwise_conv.)
        x = torch.cat((x, self.right_pad), dim=1)

        # (1, T, input_dim) -> (1, 1, T, input_dim), i.e., (N, C, H, W).
        x = x.unsqueeze(1)

        x = self.conv(x)
        x = self.convnext(x)

        # Now x is of shape (1, layer3_channels, T', ((input_dim - 1) // 2 - 1) // 2).
        b, c, t, f = x.size()  # b is equal to 1
        x = x.permute(0, 2, 1, 3).reshape(b, t, c * f)
        # Now x is of shape (1, T', layer3_channels * ((input_dim - 1) // 2 - 1) // 2).

        x = self.out(x)
        # Now x is of shape (1, T', output_dim).
        x = self.out_norm(x)

        return x, new_cached_left_pad
def get_init_states(input_dim: int, device: torch.device) -> torch.Tensor:
    """
    Get initial states for the Conv2dSubsampling module. Conv2dSubsampling.conv consists of three
    consecutive Conv2d layers with kernel size 3 and no padding along the time axis; the middle
    Conv2d has a stride of 2, while the rest have the default stride of 1. We want to pad the
    input from the left side with cached_left_pad in the "same" way, so that when we pass it
    through Conv2dSubsampling.conv and Conv2dSubsampling.convnext we end up with exactly zero
    padding frames on the left.

    (Frame-alignment diagram: cached_left_pad followed by x, traced through the three Conv2d
    layers of Conv2dSubsampling.conv and Conv2dSubsampling.convnext.depthwise_conv.)

    In order to preserve the "same" padding on the left side we therefore need

        ((((pad - 1) - 1) // 2) - 1) - 3 = 0  -->  pad = 10.

    Parameters
    ----------
    input_dim : int
        The number of input channels.
        Corresponds to the number of features in the input of the Conv2dSubsampling module.
    device : torch.device
        The device used to store the left cache tensor.
        Either torch.device("cpu") or torch.device("cuda").

    Returns
    -------
    torch.Tensor[torch.float32]
        A left cache float tensor. The output shape is (1, 10, input_dim).
    """
    pad = 10
    cached_left_pad = torch.full(
        (1, pad, input_dim), ZERO_LOG_MEL, dtype=torch.float32, device=device,
    )

    return cached_left_pad
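
A standalone arithmetic sketch of the two padding choices above (time-axis bookkeeping only; activations and norms do not change frame counts). It checks that ten cached left frames is the smallest amount consumed exactly by the convolution stack, and that ten right-padding frames keep the number of output frames at (num_frames + 1) // 2 for both even and odd inputs.

def left_pad_leftover(pad: int) -> int:
    # Left frames that survive the two unstrided Conv2d layers (kernel 3),
    # the stride-2 subsampling, and the ConvNeXt depthwise Conv2d
    # (kernel 7, no time padding, i.e., 3 frames consumed on the left).
    return ((((pad - 1) - 1) // 2) - 1) - 3

def subsampled_len(num_frames: int) -> int:
    t = num_frames + 10 + 10  # left cache + input + right padding
    t = t - 2                 # Conv2d(kernel=3, stride=1), no time padding
    t = (t - 3) // 2 + 1      # Conv2d(kernel=3, stride=2)
    t = t - 2                 # Conv2d(kernel=3, stride=(1, 2))
    t = t - 6                 # ConvNeXt depthwise Conv2d(kernel=7), no time padding
    return t

print(min(p for p in range(1, 20) if left_pad_leftover(p) == 0))  # 10

for num_frames in (8, 9, 32, 45):
    assert subsampled_len(num_frames) == (num_frames + 1) // 2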

File diff suppressed because it is too large.