modified zipformer

This commit is contained in:
Daniil 2024-10-17 00:09:54 +00:00
parent f84270c935
commit 6072407669
5 changed files with 2143 additions and 4093 deletions

View File

@ -20,115 +20,162 @@ import torch.nn.functional as F
from scaling import Balancer from scaling import Balancer
class Decoder(nn.Module): class Decoder(torch.nn.Module):
"""This class modifies the stateless decoder from the following paper: """
This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
RNN-transducer with stateless prediction network It removes the recurrent connection from the decoder, i.e., the prediction network.
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 Different from the above paper, it adds an extra Conv1d right after the embedding layer.
"""
It removes the recurrent connection from the decoder, i.e., the prediction def __init__(
network. Different from the above paper, it adds an extra Conv1d self, vocab_size: int, decoder_dim: int, context_size: int, device: torch.device,
right after the embedding layer. ) -> None:
"""
Decoder initialization.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf Parameters
----------
vocab_size : int
A number of tokens or modeling units, includes blank.
decoder_dim : int
A dimension of the decoder embeddings, and the decoder output.
context_size : int
A number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
"""
super().__init__()
self.embedding = torch.nn.Embedding(vocab_size, decoder_dim)
if context_size < 1:
raise ValueError(
'RNN-T decoder context size should be an integer greater '
f'or equal than 1, but got {context_size}.',
)
self.context_size = context_size
self.conv = torch.nn.Conv1d(
decoder_dim,
decoder_dim,
context_size,
groups=decoder_dim // 4,
bias=False,
device=device,
)
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Does a forward pass of the stateless Decoder module. Returns an output decoder tensor.
Parameters
----------
y : torch.Tensor[torch.int32]
The input integer tensor of shape (N, context_size).
The module input that corresponds to the last context_size decoded token indexes.
Returns
-------
torch.Tensor[torch.float32]
An output float tensor of shape (N, 1, decoder_dim).
"""
# this stuff about clamp() is a fix for a mismatch at utterance start,
# we use negative ids in RNN-T decoding.
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(2)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = torch.nn.functional.relu(embedding_out)
return embedding_out
class DecoderModule(torch.nn.Module):
"""
A helper module to combine decoder, decoder projection, and joiner inference together.
""" """
def __init__( def __init__(
self, self,
vocab_size: int, vocab_size: int,
decoder_dim: int, decoder_dim: int,
blank_id: int, joiner_dim: int,
context_size: int, context_size: int,
): beam: int,
device: torch.device,
) -> None:
""" """
Args: DecoderModule initialization.
vocab_size:
Number of tokens of the modeling unit including blank. Parameters
decoder_dim: ----------
Dimension of the input embedding, and of the decoder output. vocab_size:
blank_id: A number of tokens or modeling units, includes blank.
The ID of the blank symbol. decoder_dim : int
context_size: A dimension of the decoder embeddings, and the decoder output.
Number of previous words to use to predict the next word. joiner_dim : int
Input joiner dimension.
context_size : int
A number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram. 1 means bigram; 2 means trigram. n means (n+1)-gram.
beam : int
A decoder beam.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
""" """
super().__init__() super().__init__()
self.embedding = nn.Embedding( self.decoder = Decoder(vocab_size, decoder_dim, context_size, device)
num_embeddings=vocab_size, self.decoder_proj = torch.nn.Linear(decoder_dim, joiner_dim, device=device)
embedding_dim=decoder_dim, self.joiner = Joiner(joiner_dim, vocab_size, device)
)
# the balancers are to avoid any drift in the magnitude of the
# embeddings, which would interact badly with parameter averaging.
self.balancer = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
self.blank_id = blank_id
assert context_size >= 1, context_size
self.context_size = context_size
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.beam = beam
if context_size > 1: def forward(
self.conv = nn.Conv1d( self, decoder_input: torch.Tensor, encoder_out: torch.Tensor, hyps_log_prob: torch.Tensor,
in_channels=decoder_dim, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
out_channels=decoder_dim,
kernel_size=context_size,
padding=0,
groups=decoder_dim // 4, # group size == 4
bias=False,
)
self.balancer2 = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
self.balancer2 = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
""" """
Args: Does a forward pass of the stateless Decoder module. Returns an output decoder tensor.
y:
A 2-D tensor of shape (N, U). Parameters
need_pad: ----------
True to left pad the input. Should be True during training. decoder_input : torch.Tensor[torch.int32]
False to not pad the input. Should be False during inference. The input integer tensor of shape (num_hyps, context_size).
Returns: The module input that corresponds to the last context_size decoded token indexes.
Return a tensor of shape (N, U, decoder_dim). encoder_out : torch.Tensor[torch.float32]
An output tensor from the encoder after projection of shape (num_hyps, joiner_dim).
hyps_log_prob : torch.Tensor[torch.float32]
Hypothesis probabilities in a logarithmic scale of shape (num_hyps, 1).
Returns
-------
torch.Tensor[torch.float32]
A float output tensor of logit token probabilities of shape (num_hyps, vocab_size).
""" """
y = y.to(torch.int64)
# this stuff about clamp() is a temporary fix for a mismatch
# at utterance start, we use negative ids in beam_search.py
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
embedding_out = self.balancer(embedding_out) decoder_out = self.decoder(decoder_input)
decoder_out = self.decoder_proj(decoder_out)
if self.context_size > 1: logits = self.joiner(encoder_out, decoder_out[:, 0, :])
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = F.relu(embedding_out)
embedding_out = self.balancer2(embedding_out)
return embedding_out tokens_log_prob = torch.log_softmax(logits, dim=1)
log_probs = (tokens_log_prob + hyps_log_prob).reshape(-1)
hyps_topk_log_prob, topk_indexes = log_probs.topk(self.beam)
topk_hyp_indexes = torch.floor_divide(topk_indexes, self.vocab_size).to(torch.int32)
topk_token_indexes = torch.remainder(topk_indexes, self.vocab_size).to(torch.int32)
tokens_topk_prob = torch.exp(tokens_log_prob.reshape(-1)[topk_indexes])
return hyps_topk_log_prob, tokens_topk_prob, topk_hyp_indexes, topk_token_indexes

View File

@ -19,49 +19,42 @@ import torch.nn as nn
from scaling import ScaledLinear from scaling import ScaledLinear
class Joiner(nn.Module): class Joiner(torch.nn.Module):
def __init__(
self, def __init__(self, joiner_dim: int, vocab_size: int, device: torch.device) -> None:
encoder_dim: int, """
decoder_dim: int, Joiner initialization.
joiner_dim: int,
vocab_size: int, Parameters
): ----------
joiner_dim : int
Input joiner dimension.
vocab_size : int
Output joiner dimension, the vocabulary size, the number of BPEs of the model.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
"""
super().__init__() super().__init__()
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) self.output_linear = torch.nn.Linear(joiner_dim, vocab_size, device=device)
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
self.output_linear = nn.Linear(joiner_dim, vocab_size)
def forward( def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor) -> torch.Tensor:
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
project_input: bool = True,
) -> torch.Tensor:
""" """
Args: Does a forward pass of the Joiner module. Returns an output tensor after a simple joining.
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C). Parameters
decoder_out: ----------
Output from the decoder. Its shape is (N, T, s_range, C). encoder_out : torch.Tensor[torch.float32]
project_input: An output tensor from the encoder after projection of shape (N, joiner_dim).
If true, apply input projections encoder_proj and decoder_proj. decoder_out : torch.Tensor[torch.float32]
If this is false, it is the user's responsibility to do this An output tensor from the decoder after projection of shape (N, joiner_dim).
manually.
Returns: Returns
Return a tensor of shape (N, T, s_range, C). -------
torch.Tensor[torch.float32]
A float output tensor of log token probabilities of shape (N, vocab_size).
""" """
assert encoder_out.ndim == decoder_out.ndim, (
encoder_out.shape,
decoder_out.shape,
)
if project_input: return self.output_linear(torch.tanh(encoder_out + decoder_out))
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
else:
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit

File diff suppressed because it is too large Load Diff

View File

@ -36,371 +36,259 @@ from scaling import (
from torch import Tensor, nn from torch import Tensor, nn
class ConvNeXt(nn.Module): class ConvNeXt(torch.nn.Module):
""" """
Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf The simplified ConvNeXt module interpretation based on https://arxiv.org/pdf/2206.14747.pdf.
""" """
def __init__( def __init__(self, num_channels: int, device: torch.device) -> None:
self, """
channels: int, ConvNeXt initialization.
hidden_ratio: int = 3,
kernel_size: Tuple[int, int] = (7, 7), Parameters
layerdrop_rate: FloatLike = None, ----------
): num_channels : int
The number of input and output channels for ConvNeXt module.
device : torch.device
The device used to store the layer weights.
Either torch.device("cpu") or torch.device("cuda").
"""
super().__init__() super().__init__()
self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
hidden_channels = channels * hidden_ratio
if layerdrop_rate is None:
layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
self.layerdrop_rate = layerdrop_rate
self.depthwise_conv = nn.Conv2d( self.padding = 3
in_channels=channels, hidden_channels = num_channels * 3
out_channels=channels,
groups=channels,
kernel_size=kernel_size,
padding=self.padding,
)
self.pointwise_conv1 = nn.Conv2d( self.depthwise_conv = torch.nn.Conv2d(
in_channels=channels, out_channels=hidden_channels, kernel_size=1 num_channels,
) num_channels,
7,
self.hidden_balancer = Balancer( groups=num_channels,
hidden_channels, padding=(0, self.padding), # time, freq
channel_dim=1, device=device,
min_positive=0.3,
max_positive=1.0,
min_abs=0.75,
max_abs=5.0,
) )
self.activation = SwooshL() self.activation = SwooshL()
self.pointwise_conv2 = ScaledConv2d( self.pointwise_conv1 = torch.nn.Conv2d(num_channels, hidden_channels, 1, device=device)
in_channels=hidden_channels, self.pointwise_conv2 = torch.nn.Conv2d(hidden_channels, num_channels, 1, device=device)
out_channels=channels,
kernel_size=1,
initial_scale=0.01,
)
self.out_balancer = Balancer( def forward(self, x: torch.Tensor) -> torch.Tensor:
channels,
channel_dim=1,
min_positive=0.4,
max_positive=0.6,
min_abs=1.0,
max_abs=6.0,
)
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=5.0,
prob=(0.025, 0.25),
grad_scale=0.01,
)
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return self.forward_internal(x)
layerdrop_rate = float(self.layerdrop_rate)
if layerdrop_rate != 0.0:
batch_size = x.shape[0]
mask = (
torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
> layerdrop_rate
)
else:
mask = None
# turns out this caching idea does not work with --world-size > 1
# return caching_eval(self.forward_internal, x, mask)
return self.forward_internal(x, mask)
def forward_internal(
self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
) -> Tensor:
""" """
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) Does a forward pass of the ConvNeXt module.
The returned value has the same shape as x. Parameters
----------
x : torch.Tensor[torch.float32]
An input float tensor of shape (1, num_channels, num_input_frames, num_freqs).
Returns
-------
torch.Tensor[torch.float32]
A output float tensor of the same shape as input,
(1, num_channels, num_output_frames, num_freqs).
""" """
bypass = x
bypass = x[:, :, self.padding: x.size(2) - self.padding]
x = self.depthwise_conv(x) x = self.depthwise_conv(x)
x = self.pointwise_conv1(x) x = self.pointwise_conv1(x)
x = self.hidden_balancer(x)
x = self.activation(x) x = self.activation(x)
x = self.pointwise_conv2(x) x = self.pointwise_conv2(x)
if layer_skip_mask is not None:
x = x * layer_skip_mask
x = bypass + x x = bypass + x
x = self.out_balancer(x)
if x.requires_grad:
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
x = self.out_whiten(x)
x = x.transpose(1, 3) # (N, C, H, W)
return x return x
def streaming_forward(
self,
x: Tensor,
cached_left_pad: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)
Returns: class Conv2dSubsampling(torch.nn.Module):
- The returned value has the same shape as x. """
- Updated cached_left_pad. Convolutional 2D subsampling module. It performs the prior subsampling
""" (four times subsampling along the frequency axis and two times - along the time axis),
padding = self.padding and low-level descriptor feature extraction from the log mel feature input before passing
it to zipformer encoder.
# The length without right padding for depth-wise conv
T = x.size(2) - padding[0]
bypass = x[:, :, :T, :]
# Pad left side
assert cached_left_pad.size(2) == padding[0], (
cached_left_pad.size(2),
padding[0],
)
x = torch.cat([cached_left_pad, x], dim=2)
# Update cached left padding
cached_left_pad = x[:, :, T : padding[0] + T, :]
# depthwise_conv
x = torch.nn.functional.conv2d(
x,
weight=self.depthwise_conv.weight,
bias=self.depthwise_conv.bias,
padding=(0, padding[1]),
groups=self.depthwise_conv.groups,
)
x = self.pointwise_conv1(x)
x = self.hidden_balancer(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
x = bypass + x
return x, cached_left_pad
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/2 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = (T-3)//2 - 2 == (T-7)//2
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
""" """
def __init__( def __init__(
self, self,
in_channels: int, input_dim: int,
out_channels: int, output_dim: int,
layer1_channels: int = 8, layer1_channels: int,
layer2_channels: int = 32, layer2_channels: int,
layer3_channels: int = 128, layer3_channels: int,
dropout: FloatLike = 0.1, right_context: int,
device: torch.device,
) -> None: ) -> None:
""" """
Args: Conv2dSubsampling initialization.
in_channels:
Number of channels in. The input shape is (N, T, in_channels). Parameters
Caution: It requires: T >=7, in_channels >=7 ----------
out_channels input_dim : int
Output dim. The output shape is (N, (T-3)//2, out_channels) The number of input channels. Corresponds to the
layer1_channels: number of features in the input feature tensor.
Number of channels in layer1 output_dim : int
layer1_channels: The number of output channels.
Number of channels in layer2 layer1_channels : int
bottleneck: The number of output channels in the first Conv2d layer.
bottleneck dimension for 1d squeeze-excite layer2_channels : int
The number of output channels in the second Conv2d layer.
layer3_channels : int
The number of output channels in the third Conv2d layer.
right_context: int
The look-ahead right context that is used to update the left cache.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
""" """
assert in_channels >= 7
super().__init__() super().__init__()
# The ScaleGrad module is there to prevent the gradients if input_dim < 7:
# w.r.t. the weight or bias of the first Conv2d module in self.conv from raise ValueError(
# exceeding the range of fp16 when using automatic mixed precision (amp) 'The input feature dimension of the Conv2dSubsampling layer, can not be less than '
# training. (The second one is necessary to stop its bias from getting 'seven, otherwise the frequency subsampling will result with an empty output. '
# a too-large gradient). f'Expected input_dim to be at least 7 but got {input_dim}.',
)
self.conv = nn.Sequential( self.right_context = right_context
nn.Conv2d(
# Assume batch size is 1 and the right padding is 10,
# see the forward method on why the right padding is 10.
self.right_pad = torch.full(
(1, 10, input_dim), ZERO_LOG_MEL, dtype=torch.float32, device=device,
)
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(
in_channels=1, in_channels=1,
out_channels=layer1_channels, out_channels=layer1_channels,
kernel_size=3, kernel_size=3,
padding=(0, 1), # (time, freq) padding=(0, 1), # (time, freq)
device=device,
), ),
ScaleGrad(0.2),
Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
SwooshR(), SwooshR(),
nn.Conv2d( torch.nn.Conv2d(layer1_channels, layer2_channels, 3, stride=2, device=device),
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
padding=0,
),
Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
SwooshR(), SwooshR(),
nn.Conv2d( torch.nn.Conv2d(layer2_channels, layer3_channels, 3, stride=(1, 2), device=device),
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=(1, 2), # (time, freq)
),
Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
SwooshR(), SwooshR(),
) )
# just one convnext layer self.convnext = ConvNeXt(layer3_channels, device=device)
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
# (in_channels-3)//4 out_width = (((input_dim - 1) // 2) - 1) // 2
self.out_width = (((in_channels - 1) // 2) - 1) // 2 self.out = torch.nn.Linear(out_width * layer3_channels, output_dim, device=device)
self.layer3_channels = layer3_channels self.out_norm = BiasNorm(output_dim, device=device)
self.out = nn.Linear(self.out_width * layer3_channels, out_channels)
# use a larger than normal grad_scale on this whitening module; there is
# only one such module, so there is not a concern about adding together
# many copies of this extra gradient term.
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
prob=(0.025, 0.25),
grad_scale=0.02,
)
# max_log_eps=0.0 is to prevent both eps and the output of self.out from
# getting large, there is an unnecessary degree of freedom.
self.out_norm = BiasNorm(out_channels)
self.dropout = Dropout3(dropout, shared_dim=1)
def forward( def forward(
self, x: torch.Tensor, x_lens: torch.Tensor self, x: torch.Tensor, cached_left_pad: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
Returns:
- a tensor of shape (N, (T-7)//2, odim)
- output lengths, of shape (batch_size,)
""" """
# On entry, x is (N, T, idim) Does a forward pass of the Conv2dSubsampling module.
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision) Parameters
# training, since the weights in the first convolution are otherwise the limiting factor for getting infinite ----------
# gradients. x : torch.Tensor[torch.float32]
An input float tensor of shape (1, num_frames, input_dim). An input feature tensor.
cached_left_pad : torch.Tensor[torch.float32]
A left cache float tensor of shape (1, 10, input_dim). Left cache is required
to preserve the "same" left padding to the output of the Conv2dSubsampling module.
See the get_init_states() documentation to understand why we need exactly ten frames
of left padding for the Conv2dSubsampling module.
Returns
-------
tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]]
A tuple of two float tensors:
- The processing output of the Conv2dSubsampling module
of shape (1, subsampled_num_frames, output_dim).
- The udated left cache tensor of shape (1, 10, input_dim).
"""
x = torch.cat((cached_left_pad, x), dim=1)
new_cached_left_pad = x[
:,
x.size(1) - self.right_context - cached_left_pad.size(1):
x.size(1) - self.right_context,
]
# Now when we concatenated the left cache with the input, we need to perform the right
# padding of the input in a way to preserve the "same" type of padding, so that the output
# of the module has the same duration as input (taking 2 times subsampling into account).
# There are two possible outcomes depending on whether the the number of input frames is
# even or odd, but both scenarios can be covered by 10 frames right padding.
# x : right padding
# | | | | | | | | | | | |:| | | | | | | | | | input
# | | | | | | | | | | |:| | | | | | | | | first Conv2d output from self.conv
# | | | | | :| | | | second Conv2d output from self.conv
# | | | | :| | | third Conv2d output from self.conv
# | : Conv2d output from
# : self.convnext.depthwise_conv
# :
# x : right padding
# | | | | | | | | | | | | |:| | | | | | | | | | input
# | | | | | | | | | | | |:| | | | | | | | | first Conv2d output from self.conv
# | | | | | |: | | | | second Conv2d output from self.conv
# | | | | |: | | | third Conv2d output from self.conv
# | |: Conv2d output from
# : self.convnext.depthwise_conv
# :
x = torch.cat((x, self.right_pad), dim=1)
# (1, T, input_dim) -> (1, 1, T, input_dim) i.e., (N, C, H, W)
x = x.unsqueeze(1)
x = self.conv(x) x = self.conv(x)
x = self.convnext(x) x = self.convnext(x)
# Now x is of shape (N, odim, (T-7)//2, (idim-3)//4) # Now x is of shape (1, output_dim, T', ((input_dim - 1) // 2 - 1) // 2)
b, c, t, f = x.size() b, c, t, f = x.size() # b is equal to 1
x = x.permute(0, 2, 1, 3).reshape(b, t, c * f)
x = x.transpose(1, 2).reshape(b, t, c * f) # Now x is of shape (T', output_dim * layer3_channels))
# now x: (N, (T-7)//2, out_width * layer3_channels))
x = self.out(x) x = self.out(x)
# Now x is of shape (N, (T-7)//2, odim) # Now x is of shape (T', output_dim)
x = self.out_whiten(x)
x = self.out_norm(x)
x = self.dropout(x)
if torch.jit.is_scripting() or torch.jit.is_tracing():
x_lens = (x_lens - 7) // 2
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
x_lens = (x_lens - 7) // 2
assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())
return x, x_lens
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
cached_left_pad: Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
Returns:
- a tensor of shape (N, (T-7)//2, odim)
- output lengths, of shape (batch_size,)
- updated cache
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# T' = (T-7)//2
x = self.conv(x)
# T' = (T-7)//2-3
x, cached_left_pad = self.convnext.streaming_forward(
x, cached_left_pad=cached_left_pad
)
# Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, T', out_width * layer3_channels))
x = self.out(x)
# Now x is of shape (N, T', odim)
x = self.out_norm(x) x = self.out_norm(x)
if torch.jit.is_scripting() or torch.jit.is_tracing(): return x, new_cached_left_pad
assert self.convnext.padding[0] == 3
# The ConvNeXt module needs 3 frames of right padding after subsampling
x_lens = (x_lens - 7) // 2 - 3
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# The ConvNeXt module needs 3 frames of right padding after subsampling
assert self.convnext.padding[0] == 3
x_lens = (x_lens - 7) // 2 - 3
assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max()) def get_init_states(input_dim: int, device: torch.device) -> torch.Tensor:
"""
Get initial states for Conv2dSubsampling module. The Conv2dSubsampling.conv consists of three
consecutive Conv2d layers with the kernel size 3 and no padding, also the middle Conv2d
has a stride 2, while the rest have the default stride 1. We want to pad the input from the
left side with cached_left_pad in the "same" way, so when we pass it through
the Conv2dSubsampling.conv and Conv2dSubsampling.convnext we end up with exactly zero padding
frames from the left.
return x, x_lens, cached_left_pad cached_left_pad : x
| | | | | | | | | |:| | | | | | | | | | | input
| | | | | | | | |:| | | | | | | | | | | first Conv2d output from Conv2dSubsampling.conv
| | | | :| | | | | | ... second Conv2d output from Conv2dSubsampling.conv
| | | :| | | | | | third Conv2d output from Conv2dSubsampling.conv
:| | | | | | Conv2d output from
: Conv2dSubsampling.convnext.depthwise_conv
@torch.jit.export As we can see from the picture above, in order to preserve the "same"
def get_init_states( padding from the left side we need
self, ((((pad - 1) - 1) // 2) - 1) - 3 = 0 --> pad = 10.
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> Tensor:
"""Get initial states for Conv2dSubsampling module.
It is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
"""
left_pad = self.convnext.padding[0]
freq = self.out_width
channels = self.layer3_channels
cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
device
)
return cached_embed_left_pad Parameters
----------
input_dim : int
The number of input channels.
Corresponds to the number of features in the input of the Conv2dSubsampling module.
device : torch.device
The device used to store the left cache tensor.
Either torch.device("cpu") or torch.device("cuda").
Returns
-------
torch.Tensor[torch.float32]
A left cache float tensor. The output shape is (1, 10, input_dim).
"""
pad = 10
cached_left_pad = torch.full(
(1, pad, input_dim), ZERO_LOG_MEL, dtype=torch.float32, device=device,
)
return cached_left_pad

File diff suppressed because it is too large Load Diff