mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
295 lines · 11 KiB · Python
#!/usr/bin/env python3
|
|
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey,
|
|
# Zengwei Yao)
|
|
#
|
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import warnings
|
|
from typing import Tuple
|
|
|
|
import torch
|
|
from scaling import (
|
|
Balancer,
|
|
BiasNorm,
|
|
Dropout3,
|
|
FloatLike,
|
|
Optional,
|
|
ScaledConv2d,
|
|
ScaleGrad,
|
|
ScheduledFloat,
|
|
SwooshL,
|
|
SwooshR,
|
|
Whiten,
|
|
)
|
|
from torch import Tensor, nn
|
|
|
|
|
|
class ConvNeXt(torch.nn.Module):
    """
    A simplified ConvNeXt residual block, interpreted from
    https://arxiv.org/pdf/2206.14747.pdf.
    """

    def __init__(self, num_channels: int, device: torch.device) -> None:
        """
        ConvNeXt initialization.

        Parameters
        ----------
        num_channels : int
            The number of input and output channels for ConvNeXt module.
        device : torch.device
            The device used to store the layer weights.
            Either torch.device("cpu") or torch.device("cuda").
        """

        super().__init__()

        # Frames trimmed from each end of the time axis by the unpadded
        # depthwise convolution (kernel 7 -> 3 per side).
        self.padding = 3

        expansion_channels = num_channels * 3

        self.depthwise_conv = torch.nn.Conv2d(
            num_channels,
            num_channels,
            7,
            groups=num_channels,
            padding=(0, self.padding),  # time, freq
            device=device,
        )
        self.activation = SwooshL()
        self.pointwise_conv1 = torch.nn.Conv2d(
            num_channels, expansion_channels, 1, device=device,
        )
        self.pointwise_conv2 = torch.nn.Conv2d(
            expansion_channels, num_channels, 1, device=device,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Does a forward pass of the ConvNeXt module.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            An input float tensor of shape (1, num_channels, num_input_frames, num_freqs).

        Returns
        -------
        torch.Tensor[torch.float32]
            An output float tensor of the same layout as the input,
            (1, num_channels, num_output_frames, num_freqs).
        """

        # The convolutions shrink the time axis by self.padding frames on each
        # side, so crop the residual branch to match before adding.
        bypass = x[:, :, self.padding : x.size(2) - self.padding]

        out = self.depthwise_conv(x)
        out = self.pointwise_conv1(out)
        out = self.activation(out)
        out = self.pointwise_conv2(out)

        return bypass + out
|
|
|
|
|
|
class Conv2dSubsampling(torch.nn.Module):
    """
    Convolutional 2D subsampling module. It performs the prior subsampling
    (four times subsampling along the frequency axis and two times - along the time axis),
    and low-level descriptor feature extraction from the log mel feature input before passing
    it to zipformer encoder.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        layer1_channels: int,
        layer2_channels: int,
        layer3_channels: int,
        right_context: int,
        device: torch.device,
    ) -> None:
        """
        Conv2dSubsampling initialization.

        Parameters
        ----------
        input_dim : int
            The number of input channels. Corresponds to the
            number of features in the input feature tensor.
        output_dim : int
            The number of output channels.
        layer1_channels : int
            The number of output channels in the first Conv2d layer.
        layer2_channels : int
            The number of output channels in the second Conv2d layer.
        layer3_channels : int
            The number of output channels in the third Conv2d layer.
        right_context : int
            The look-ahead right context that is used to update the left cache.
        device : torch.device
            The device used to store the layer weights. Should be
            either torch.device("cpu") or torch.device("cuda").

        Raises
        ------
        ValueError
            If input_dim is less than 7 (frequency subsampling would be empty).
        """

        super().__init__()

        if input_dim < 7:
            raise ValueError(
                'The input feature dimension of the Conv2dSubsampling layer, can not be less than '
                'seven, otherwise the frequency subsampling will result with an empty output. '
                f'Expected input_dim to be at least 7 but got {input_dim}.',
            )

        self.right_context = right_context

        # Assume batch size is 1 and the right padding is 10,
        # see the forward method on why the right padding is 10.
        # NOTE(review): ZERO_LOG_MEL is defined elsewhere in this file —
        # presumably the log-mel value representing silence/zero energy.
        self.right_pad = torch.full(
            (1, 10, input_dim), ZERO_LOG_MEL, dtype=torch.float32, device=device,
        )
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=1,
                out_channels=layer1_channels,
                kernel_size=3,
                padding=(0, 1),  # (time, freq)
                device=device,
            ),
            SwooshR(),
            # Subsamples time and frequency by 2 (no padding).
            torch.nn.Conv2d(layer1_channels, layer2_channels, 3, stride=2, device=device),
            SwooshR(),
            # Subsamples frequency by another factor of 2; time stride stays 1.
            torch.nn.Conv2d(layer2_channels, layer3_channels, 3, stride=(1, 2), device=device),
            SwooshR(),
        )

        self.convnext = ConvNeXt(layer3_channels, device=device)

        # Frequency width after the two unpadded stride-2 reductions above.
        out_width = (((input_dim - 1) // 2) - 1) // 2
        self.out = torch.nn.Linear(out_width * layer3_channels, output_dim, device=device)
        self.out_norm = BiasNorm(output_dim, device=device)

    def forward(
        self, x: torch.Tensor, cached_left_pad: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Does a forward pass of the Conv2dSubsampling module.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            An input float tensor of shape (1, num_frames, input_dim). An input feature tensor.
        cached_left_pad : torch.Tensor[torch.float32]
            A left cache float tensor of shape (1, 10, input_dim). Left cache is required
            to preserve the "same" left padding to the output of the Conv2dSubsampling module.
            See the get_init_states() documentation to understand why we need exactly ten frames
            of left padding for the Conv2dSubsampling module.

        Returns
        -------
        tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]]
            A tuple of two float tensors:
            - The processing output of the Conv2dSubsampling module
              of shape (1, subsampled_num_frames, output_dim).
            - The updated left cache tensor of shape (1, 10, input_dim).
        """

        x = torch.cat((cached_left_pad, x), dim=1)
        # The new left cache is the last cached_left_pad.size(1) frames that
        # precede the right_context look-ahead frames at the end of the input.
        new_cached_left_pad = x[
            :,
            x.size(1) - self.right_context - cached_left_pad.size(1):
            x.size(1) - self.right_context,
        ]

        # Now when we concatenated the left cache with the input, we need to perform the right
        # padding of the input in a way to preserve the "same" type of padding, so that the output
        # of the module has the same duration as input (taking 2 times subsampling into account).
        # There are two possible outcomes depending on whether the number of input frames is
        # even or odd, but both scenarios can be covered by 10 frames right padding.

        #                     x                     :    right padding
        # | | | | | | | | | | | |:| | | | | | | | | |   input
        #  | | | | | | | | | | |:| | | | | | | | |      first Conv2d output from self.conv
        #    |   |   |   |   | :|   |   |   |           second Conv2d output from self.conv
        #      |   |   |   | :|   |   |                 third Conv2d output from self.conv
        #            |       :                          Conv2d output from
        #                    :                          self.convnext.depthwise_conv
        #                    :
        #                      x                   :    right padding
        # | | | | | | | | | | | | |:| | | | | | | | |   input
        #  | | | | | | | | | | | |:| | | | | | | |      first Conv2d output from self.conv
        #    |   |   |   |   |   |: |   |   |           second Conv2d output from self.conv
        #      |   |   |   |   |: |   |                 third Conv2d output from self.conv
        #            |       | :                        Conv2d output from
        #                    :                          self.convnext.depthwise_conv
        #                    :

        x = torch.cat((x, self.right_pad), dim=1)

        # (1, T, input_dim) -> (1, 1, T, input_dim) i.e., (N, C, H, W)
        x = x.unsqueeze(1)
        x = self.conv(x)
        x = self.convnext(x)

        # Now x is of shape (1, layer3_channels, T', ((input_dim - 1) // 2 - 1) // 2)
        b, c, t, f = x.size()  # b is equal to 1
        x = x.permute(0, 2, 1, 3).reshape(b, t, c * f)
        # Now x is of shape (1, T', layer3_channels * out_width)
        x = self.out(x)
        # Now x is of shape (1, T', output_dim)
        x = self.out_norm(x)

        return x, new_cached_left_pad
|
|
|
|
def get_init_states(input_dim: int, device: torch.device) -> torch.Tensor:
    """
    Get initial states for Conv2dSubsampling module. The Conv2dSubsampling.conv consists of three
    consecutive Conv2d layers with the kernel size 3 and no padding, also the middle Conv2d
    has a stride 2, while the rest have the default stride 1. We want to pad the input from the
    left side with cached_left_pad in the "same" way, so when we pass it through
    the Conv2dSubsampling.conv and Conv2dSubsampling.convnext we end up with exactly zero padding
    frames from the left.

    cached_left_pad    :              x
    | | | | | | | | | |:| | | | | | | | | | |      input
     | | | | | | | | |:| | | | | | | | | | |       first Conv2d output from Conv2dSubsampling.conv
       |   |   |   | :|   |   |   |   |  ...       second Conv2d output from Conv2dSubsampling.conv
         |   |   | :|   |   |   |   |              third Conv2d output from Conv2dSubsampling.conv
                   :|       |       |              Conv2d output from
                   :                               Conv2dSubsampling.convnext.depthwise_conv

    As we can see from the picture above, in order to preserve the "same"
    padding from the left side we need
    ((((pad - 1) - 1) // 2) - 1) - 3 = 0  -->  pad = 10.

    Parameters
    ----------
    input_dim : int
        The number of input channels.
        Corresponds to the number of features in the input of the Conv2dSubsampling module.
    device : torch.device
        The device used to store the left cache tensor.
        Either torch.device("cpu") or torch.device("cuda").

    Returns
    -------
    torch.Tensor[torch.float32]
        A left cache float tensor. The output shape is (1, 10, input_dim).
    """

    # Ten frames of "silence" padding; see the derivation in the docstring.
    num_pad_frames = 10

    return torch.full(
        (1, num_pad_frames, input_dim),
        ZERO_LOG_MEL,
        dtype=torch.float32,
        device=device,
    )
|