2024-10-17 00:09:54 +00:00

295 lines
11 KiB
Python

#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Tuple
import torch
from scaling import (
Balancer,
BiasNorm,
Dropout3,
FloatLike,
Optional,
ScaledConv2d,
ScaleGrad,
ScheduledFloat,
SwooshL,
SwooshR,
Whiten,
)
from torch import Tensor, nn
class ConvNeXt(torch.nn.Module):
    """
    The simplified ConvNeXt module interpretation based on https://arxiv.org/pdf/2206.14747.pdf.

    The block applies a depthwise 7x7 convolution, a pointwise expansion to three times
    the number of channels, a SwooshL activation and a pointwise projection back to
    ``num_channels``. The result is added to a time-cropped copy of the input.
    """

    def __init__(self, num_channels: int, device: torch.device) -> None:
        """
        ConvNeXt initialization.

        Parameters
        ----------
        num_channels : int
            The number of input and output channels for ConvNeXt module.
        device : torch.device
            The device used to store the layer weights.
            Either torch.device("cpu") or torch.device("cuda").
        """
        super().__init__()

        # Half of the depthwise kernel (7 // 2). It is used as the frequency padding of the
        # depthwise convolution and as the per-side time crop of the bypass in forward().
        self.padding = 3

        # The time axis is deliberately left unpadded, so every call consumes
        # 2 * self.padding frames of temporal context.
        self.depthwise_conv = torch.nn.Conv2d(
            in_channels=num_channels,
            out_channels=num_channels,
            kernel_size=7,
            padding=(0, self.padding),  # (time, freq)
            groups=num_channels,
            device=device,
        )

        self.activation = SwooshL()

        expanded_channels = 3 * num_channels
        self.pointwise_conv1 = torch.nn.Conv2d(
            num_channels, expanded_channels, kernel_size=1, device=device,
        )
        self.pointwise_conv2 = torch.nn.Conv2d(
            expanded_channels, num_channels, kernel_size=1, device=device,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Does a forward pass of the ConvNeXt module.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            An input float tensor of shape (1, num_channels, num_input_frames, num_freqs).

        Returns
        -------
        torch.Tensor[torch.float32]
            An output float tensor of shape
            (1, num_channels, num_input_frames - 2 * self.padding, num_freqs).
            The time axis is shorter than the input because the depthwise
            convolution is not padded along time.
        """
        crop = self.padding
        # Drop `crop` frames on both sides of the time axis so that the residual branch
        # lines up with the output of the unpadded (in time) depthwise convolution.
        residual = x[:, :, crop:x.size(2) - crop]

        expanded = self.pointwise_conv1(self.depthwise_conv(x))
        projected = self.pointwise_conv2(self.activation(expanded))

        return residual + projected
class Conv2dSubsampling(torch.nn.Module):
    """
    Convolutional 2D subsampling module. It performs the prior subsampling
    (four times subsampling along the frequency axis and two times - along the time axis),
    and low-level descriptor feature extraction from the log mel feature input before passing
    it to zipformer encoder.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        layer1_channels: int,
        layer2_channels: int,
        layer3_channels: int,
        right_context: int,
        device: torch.device,
    ) -> None:
        """
        Conv2dSubsampling initialization.

        Parameters
        ----------
        input_dim : int
            The number of input channels. Corresponds to the
            number of features in the input feature tensor.
        output_dim : int
            The number of output channels.
        layer1_channels : int
            The number of output channels in the first Conv2d layer.
        layer2_channels : int
            The number of output channels in the second Conv2d layer.
        layer3_channels : int
            The number of output channels in the third Conv2d layer.
        right_context: int
            The look-ahead right context that is used to update the left cache: the last
            right_context frames of the input are excluded from the new left cache.
        device : torch.device
            The device used to store the layer weights. Should be
            either torch.device("cpu") or torch.device("cuda").

        Raises
        ------
        ValueError
            If input_dim is less than 7.
        """
        super().__init__()

        if input_dim < 7:
            message = (
                'The input feature dimension of the Conv2dSubsampling layer, can not be less than '
                'seven, otherwise the frequency subsampling will result with an empty output. '
                f'Expected input_dim to be at least 7 but got {input_dim}.'
            )
            raise ValueError(message)

        self.right_context = right_context

        # Assume batch size is 1 and the right padding is 10,
        # see the forward method on why the right padding is 10.
        # NOTE(review): ZERO_LOG_MEL is not defined inside this class; it has to be
        # available at module scope - confirm where it comes from.
        right_pad_shape = (1, 10, input_dim)
        self.right_pad = torch.full(
            right_pad_shape, ZERO_LOG_MEL, dtype=torch.float32, device=device,
        )

        # The layers are created in the order in which they are applied.
        conv_layers = [
            # Keeps the frequency width (padding 1), consumes 2 frames along time.
            torch.nn.Conv2d(
                1, layer1_channels, kernel_size=3, padding=(0, 1), device=device,  # (time, freq)
            ),
            SwooshR(),
            # Stride 2 along both time and frequency.
            torch.nn.Conv2d(
                layer1_channels, layer2_channels, kernel_size=3, stride=2, device=device,
            ),
            SwooshR(),
            # Stride 2 along frequency only.
            torch.nn.Conv2d(
                layer2_channels, layer3_channels, kernel_size=3, stride=(1, 2), device=device,
            ),
            SwooshR(),
        ]
        self.conv = torch.nn.Sequential(*conv_layers)

        self.convnext = ConvNeXt(layer3_channels, device=device)

        # Frequency width left after the two stride-2 (in frequency) convolutions above.
        freq_width = ((input_dim - 1) // 2 - 1) // 2
        self.out = torch.nn.Linear(freq_width * layer3_channels, output_dim, device=device)
        self.out_norm = BiasNorm(output_dim, device=device)

    def forward(
        self, x: torch.Tensor, cached_left_pad: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Does a forward pass of the Conv2dSubsampling module.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            An input float tensor of shape (1, num_frames, input_dim). An input feature tensor.
        cached_left_pad : torch.Tensor[torch.float32]
            A left cache float tensor of shape (1, 10, input_dim). Left cache is required
            to preserve the "same" left padding to the output of the Conv2dSubsampling module.
            See the get_init_states() documentation to understand why we need exactly ten frames
            of left padding for the Conv2dSubsampling module.

        Returns
        -------
        tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]]
            A tuple of two float tensors:
            - The processing output of the Conv2dSubsampling module
            of shape (1, subsampled_num_frames, output_dim).
            - The updated left cache tensor of shape (1, 10, input_dim).
        """
        cache_len = cached_left_pad.size(1)
        frames = torch.cat((cached_left_pad, x), dim=1)

        # The new left cache consists of the cache_len frames that immediately precede
        # the last self.right_context frames of the (left cache + input) sequence.
        cache_end = frames.size(1) - self.right_context
        new_cached_left_pad = frames[:, cache_end - cache_len:cache_end]

        # Now when we concatenated the left cache with the input, we need to perform the right
        # padding of the input in a way to preserve the "same" type of padding, so that the output
        # of the module has the same duration as input (taking 2 times subsampling into account).
        # There are two possible outcomes depending on whether the number of input frames is
        # even or odd, but both scenarios can be covered by 10 frames right padding.
        #
        # Even number of frames ("x" is the sequence before the right padding):
        #
        #            x           :  right padding
        # | | | | | | | | | | | |:| | | | | | | | | |  input
        #   | | | | | | | | | | |:| | | | | | | | |    first Conv2d output from self.conv
        #     |   |   |   |   |  :|   |   |   |        second Conv2d output from self.conv
        #         |   |   |   |  :|   |   |            third Conv2d output from self.conv
        #                     |  :                     Conv2d output from
        #                        :                     self.convnext.depthwise_conv
        #
        # Odd number of frames:
        #
        #             x            :  right padding
        # | | | | | | | | | | | | |:| | | | | | | | | |  input
        #   | | | | | | | | | | | |:| | | | | | | | |    first Conv2d output from self.conv
        #     |   |   |   |   |   |:  |   |   |   |      second Conv2d output from self.conv
        #         |   |   |   |   |:  |   |   |          third Conv2d output from self.conv
        #                     |   |:                     Conv2d output from
        #                          :                     self.convnext.depthwise_conv
        frames = torch.cat((frames, self.right_pad), dim=1)

        # (1, T, input_dim) -> (1, 1, T, input_dim) i.e., (N, C, H, W)
        feats = self.convnext(self.conv(frames.unsqueeze(1)))

        # feats is of shape (1, layer3_channels, T', ((input_dim - 1) // 2 - 1) // 2)
        batch, channels, num_out_frames, freq_width = feats.size()  # batch is equal to 1

        # Move time in front of the channels and merge channels with frequency:
        # (1, T', layer3_channels * freq_width)
        feats = feats.permute(0, 2, 1, 3).reshape(batch, num_out_frames, channels * freq_width)

        # (1, T', output_dim)
        feats = self.out_norm(self.out(feats))

        return feats, new_cached_left_pad
def get_init_states(input_dim: int, device: torch.device) -> torch.Tensor:
    """
    Get initial states for Conv2dSubsampling module. The Conv2dSubsampling.conv consists of three
    consecutive Conv2d layers with the kernel size 3 and no padding, also the middle Conv2d
    has a stride 2, while the rest have the default stride 1. We want to pad the input from the
    left side with cached_left_pad in the "same" way, so when we pass it through
    the Conv2dSubsampling.conv and Conv2dSubsampling.convnext we end up with exactly zero padding
    frames from the left.

      cached_left_pad  :          x
    | | | | | | | | | |:| | | | | | | | | | | ...  input
      | | | | | | | | |:| | | | | | | | | | | ...  first Conv2d output from Conv2dSubsampling.conv
        |   |   |   |  :|   |   |   |   |   | ...  second Conv2d output from Conv2dSubsampling.conv
            |   |   |  :|   |   |   |   |   | ...  third Conv2d output from Conv2dSubsampling.conv
                       :|   |   |   |   |   | ...  Conv2d output from
                       :                           Conv2dSubsampling.convnext.depthwise_conv

    As we can see from the picture above, in order to preserve the "same"
    padding from the left side we need

    ((((pad - 1) - 1) // 2) - 1) - 3 = 0 --> pad = 10.

    Parameters
    ----------
    input_dim : int
        The number of input channels.
        Corresponds to the number of features in the input of the Conv2dSubsampling module.
    device : torch.device
        The device used to store the left cache tensor.
        Either torch.device("cpu") or torch.device("cuda").

    Returns
    -------
    torch.Tensor[torch.float32]
        A left cache float tensor filled with ZERO_LOG_MEL.
        The output shape is (1, 10, input_dim).
    """
    # NOTE(review): ZERO_LOG_MEL is not defined inside this function; it has to be
    # available at module scope - confirm where it comes from.
    num_left_pad_frames = 10  # see the derivation in the docstring
    cache_shape = (1, num_left_pad_frames, input_dim)

    return torch.full(cache_shape, ZERO_LOG_MEL, dtype=torch.float32, device=device)