#!/usr/bin/env python3
# Copyright    2022-2023  Xiaomi Corp.        (authors: Daniel Povey,
#                                                        Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import math
import random
import warnings
from typing import List, Optional, Tuple, Union

import torch
from encoder_interface import EncoderInterface
from scaling import (
    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
)
from scaling import (
    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
)
from scaling import (
    ActivationDropoutAndLinear,
    Balancer,
    BiasNorm,
    ChunkCausalDepthwiseConv1d,
    Dropout2,
    FloatLike,
    ScheduledFloat,
    Whiten,
    convert_num_channels,
    limit_param_value,
    penalize_abs_values_gt,
    softmax,
)
from torch import Tensor, nn


class Zipformer2(torch.nn.Module):
    """
    Zipformer2 encoder.
    """
    # pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        input_dim: int,
        subsample_output_dim: int,
        subsample_layer1_channels: int,
        subsample_layer2_channels: int,
        subsample_layer3_channels: int,
        encoder_dims: list[int],
        num_encoder_layers: list[int],
        downsampling_factors: list[int],
        num_heads: list[int],
        feedforward_dims: list[int],
        cnn_module_kernels: list[int],
        query_head_dim: int,
        pos_head_dim: int,
        value_head_dim: int,
        pos_dim: int,
        pos_max_len: int,
        output_dim: int,
        use_ctc: bool,
        left_context_frames: int,
        right_context_frames: int,
        device: torch.device,
    ) -> None:
        """
        Zipformer2 initialization.

        Parameters
        ----------
        input_dim : int
            The number of input features.
        subsample_output_dim : int
            The output dimension of the subsampling module represented by Conv2dSubsampling.
        subsample_layer1_channels : int
            The number of output channels in the first Conv2d layer of the
            Conv2dSubsampling module.
        subsample_layer2_channels : int
            The number of output channels in the second Conv2d layer of the
            Conv2dSubsampling module.
        subsample_layer3_channels : int
            The number of output channels in the third Conv2d layer of the
            Conv2dSubsampling module.
        encoder_dims : list[int]
            A list of 6 integers, the embedding dimension of the
            Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
        num_encoder_layers : list[int]
            A list of 6 integers, the number of Zipformer2EncoderLayer
            modules in each Zipformer2Encoder stack.
        downsampling_factors : list[int]
            A list of 6 integers, the downsampling factor of each Zipformer2Encoder stack.
            Note: this is in addition to the downsampling factor of 2 that is applied in the
            Conv2dSubsampling module.
        num_heads : list[int]
            A list of 6 integers, the number of heads for attention weights and self-attention of
            the Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
        feedforward_dims : list[int]
            A list of 6 integers, the hidden dimension of the feedforward module of
            the Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
        cnn_module_kernels : list[int]
            A list of 6 integers, the kernel size of the convolution module of
            the Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
        query_head_dim : int
            The dimension of the query and key per attention head in attention weights of the
            Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
        pos_head_dim : int
            The dimension of the projected positional encoding per attention head in attention
            weights of the Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
        value_head_dim : int
            The dimension of the value per attention head in self-attention of
            the Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
        pos_dim : int
            The dimension of the relative positional embeddings in each Zipformer2Encoder stack.
        pos_max_len : int
            The maximum input duration of the relative positional embeddings in each
            Zipformer2Encoder stack. Note: if the input duration of any positional embedding
            module exceeds this number, one might end up with a significant degradation of
            inference speed.
        output_dim : int
            The output dimension after the final output projection.
        use_ctc : bool
            If True, a CTC head is assumed to be loaded into the output encoder projection.
            In this case torch.nn.functional.log_softmax will be applied to the output
            at the very end.
        left_context_frames : int
            The left context number of frames after the initial subsampling with the
            Conv2dSubsampling module.
        right_context_frames : int
            The right (look-ahead) context number of frames.
        device : torch.device
            The device used to store the layer weights. Should be
            either torch.device("cpu") or torch.device("cuda").
        """
        # pylint: disable=too-many-arguments,too-many-locals

        super().__init__()

        if not (
            len(encoder_dims)
            == len(num_encoder_layers)
            == len(downsampling_factors)
            == len(num_heads)
            == len(feedforward_dims)
            == len(cnn_module_kernels)
            == 6
        ):
            raise ValueError(
                'It is required that the length of encoder_dims, num_encoder_layers, '
                'downsampling_factors, num_heads, feedforward_dims, and cnn_module_kernels is the '
                'same and equal to 6, but got following list lengths:\n'
                f'len(num_encoder_layers) == {len(num_encoder_layers)}\n'
                f'len(downsampling_factors) == {len(downsampling_factors)}\n'
                f'len(encoder_dims) == {len(encoder_dims)}\n'
                f'len(num_heads) == {len(num_heads)}\n'
                f'len(cnn_module_kernels) == {len(cnn_module_kernels)}\n'
                f'len(feedforward_dims) == {len(feedforward_dims)}.',
            )

        self.encoder_dims = tuple(encoder_dims)
        self.downsampling_factors = tuple(downsampling_factors)
        self.left_context_frames = left_context_frames
        projection_dim = max(encoder_dims)
        self.projection_dim = projection_dim
        self.ctc = use_ctc

        self.subsampling = Conv2dSubsampling(
            input_dim,
            subsample_output_dim,
            subsample_layer1_channels,
            subsample_layer2_channels,
            subsample_layer3_channels,
            right_context_frames,
            device,
        )

        encoders = []
        for i, num_layers in enumerate(num_encoder_layers):

            encoder_layer = Zipformer2EncoderLayer(
                encoder_dims[i],
                pos_dim,
                num_heads[i],
                query_head_dim,
                pos_head_dim,
                value_head_dim,
                feedforward_dims[i],
                cnn_module_kernels[i],
                left_context_frames // downsampling_factors[i],
                right_context_frames // 2 // downsampling_factors[i],
                device,
            )

            encoder = Zipformer2Encoder(
                encoder_layer,
                num_layers,
                encoder_dims[i],
                pos_dim,
                pos_max_len,
                downsampling_factors[i],
                device,
            )

            encoders.append(encoder)

        self.encoder_1 = encoders[0]
        self.encoder_2 = encoders[1]
        self.encoder_3 = encoders[2]
        self.encoder_4 = encoders[3]
        self.encoder_5 = encoders[4]
        self.encoder_6 = encoders[5]

        self.downsample_output = SimpleDownsample(2, device)
        self.projection_output = torch.nn.Linear(projection_dim, output_dim, device=device)

    def forward(
        self,
        x: torch.Tensor,
        # We need to preserve this explicit argument listing for smooth
        # TorchScript export followed by ONNX export.
        left_cached_subsample_frames: torch.Tensor,

        left_cached_keys_encoder_1: torch.Tensor,
        left_cached_nonlin_attentions_encoder_1: torch.Tensor,
        left_cached_values_1_encoder_1: torch.Tensor,
        left_cached_values_2_encoder_1: torch.Tensor,
        left_cached_convolutions_1_encoder_1: torch.Tensor,
        left_cached_convolutions_2_encoder_1: torch.Tensor,

        left_cached_keys_encoder_2: torch.Tensor,
        left_cached_nonlin_attentions_encoder_2: torch.Tensor,
        left_cached_values_1_encoder_2: torch.Tensor,
        left_cached_values_2_encoder_2: torch.Tensor,
        left_cached_convolutions_1_encoder_2: torch.Tensor,
        left_cached_convolutions_2_encoder_2: torch.Tensor,

        left_cached_keys_encoder_3: torch.Tensor,
        left_cached_nonlin_attentions_encoder_3: torch.Tensor,
        left_cached_values_1_encoder_3: torch.Tensor,
        left_cached_values_2_encoder_3: torch.Tensor,
        left_cached_convolutions_1_encoder_3: torch.Tensor,
        left_cached_convolutions_2_encoder_3: torch.Tensor,

        left_cached_keys_encoder_4: torch.Tensor,
        left_cached_nonlin_attentions_encoder_4: torch.Tensor,
        left_cached_values_1_encoder_4: torch.Tensor,
        left_cached_values_2_encoder_4: torch.Tensor,
        left_cached_convolutions_1_encoder_4: torch.Tensor,
        left_cached_convolutions_2_encoder_4: torch.Tensor,

        left_cached_keys_encoder_5: torch.Tensor,
        left_cached_nonlin_attentions_encoder_5: torch.Tensor,
        left_cached_values_1_encoder_5: torch.Tensor,
        left_cached_values_2_encoder_5: torch.Tensor,
        left_cached_convolutions_1_encoder_5: torch.Tensor,
        left_cached_convolutions_2_encoder_5: torch.Tensor,

        left_cached_keys_encoder_6: torch.Tensor,
        left_cached_nonlin_attentions_encoder_6: torch.Tensor,
        left_cached_values_1_encoder_6: torch.Tensor,
        left_cached_values_2_encoder_6: torch.Tensor,
        left_cached_convolutions_1_encoder_6: torch.Tensor,
        left_cached_convolutions_2_encoder_6: torch.Tensor,

        processed_len: torch.Tensor,
    ) -> tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
        torch.Tensor, torch.Tensor, torch.Tensor,
    ]:
        """
        Does a forward pass of the Zipformer2 module, which represents the whole acoustic encoder.
        Returns a tuple with the output tensor, the updated left cache feature tensor for the
        subsampling module, 36 left cache tensors (six per Zipformer2Encoder) for the attention
        and convolution modules of the 6 Zipformer2Encoder modules, and finally the updated
        processed length single-element tensor with the total number of processed frames after
        the subsampling module.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            The input float feature tensor of shape (1, num_frames, input_dim),
            where the input_dim corresponds to the number of features.
        left_cached_subsample_frames : torch.Tensor[torch.float32]
            The subsampling module left cache tensor of shape (1, 10, input_dim).
        left_cached_keys_encoder_1 : torch.Tensor[torch.float32]
            The cached attention key tensor of the left context of each
            Zipformer2EncoderLayer within the first Zipformer2Encoder.
            The tensor is of shape (num_layers_1, 1, left_context_len_1, query_dim_1).
        left_cached_nonlin_attentions_encoder_1 : torch.Tensor[torch.float32]
            The left context cached attention tensor for the non-linear attention module of each
            Zipformer2EncoderLayer within the first Zipformer2Encoder.
            The tensor is of shape (num_layers_1, 1, left_context_len_1, head_dim_1).
        left_cached_values_1_encoder_1 : torch.Tensor[torch.float32]
            The cached left context tensor for the first self-attention module of each
            Zipformer2EncoderLayer within the first Zipformer2Encoder.
            The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1).
        left_cached_values_2_encoder_1 : torch.Tensor[torch.float32]
            The cached left context tensor for the second self-attention module of each
            Zipformer2EncoderLayer within the first Zipformer2Encoder.
            The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1).
        left_cached_convolutions_1_encoder_1 : torch.Tensor[torch.float32]
            The cached left context tensor for the first convolution module of each
            Zipformer2EncoderLayer within the first Zipformer2Encoder.
            The tensor is of shape (num_layers_1, 1, embed_dim_1, left_cache_len_1).
        left_cached_convolutions_2_encoder_1 : torch.Tensor[torch.float32]
            The cached left context tensor for the second convolution module of each
            Zipformer2EncoderLayer within the first Zipformer2Encoder.
            The tensor is of shape (num_layers_1, 1, embed_dim_1, left_cache_len_1).
        .
        .
        .
        left_cached_convolutions_2_encoder_6 : torch.Tensor[torch.float32]
            The cached left context tensor for the second convolution module of each
            Zipformer2EncoderLayer within the sixth Zipformer2Encoder.
            The tensor is of shape (num_layers_6, 1, embed_dim_6, left_cache_len_6).
        processed_len : torch.Tensor[torch.int32]
            The total processed length after subsampling, a single-element integer tensor
            of shape (1,).

        Returns
        -------
        tuple[
            torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
            torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
            torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
            torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
            torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
            torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
            torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
            torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
            torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
            torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
            torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
            torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
            torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.int32],
        ]
            A tuple of 38 float tensors and 1 integer tensor:
            - The module output of shape (1, seq_len, output_dim).
            - The updated subsampling module left cache tensor of shape (1, 10, input_dim).
            - The updated cached attention key tensor of the left context of each
              Zipformer2EncoderLayer within the first Zipformer2Encoder.
              The tensor is of shape (num_layers_1, 1, left_context_len_1, query_dim_1).
            - The updated left context cached attention tensor for the non-linear attention
              module of each Zipformer2EncoderLayer within the first Zipformer2Encoder.
              The tensor is of shape (num_layers_1, 1, left_context_len_1, head_dim_1).
            - The updated cached left context tensor for the first self-attention module of each
              Zipformer2EncoderLayer within the first Zipformer2Encoder.
              The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1).
            - The updated cached left context tensor for the second self-attention module of each
              Zipformer2EncoderLayer within the first Zipformer2Encoder.
              The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1).
            - The updated cached left context tensor for the first convolution module of each
              Zipformer2EncoderLayer within the first Zipformer2Encoder.
              The tensor is of shape (num_layers_1, 1, embed_dim_1, left_cache_len_1).
            - The updated cached left context tensor for the second convolution module of each
              Zipformer2EncoderLayer within the first Zipformer2Encoder.
              The tensor is of shape (num_layers_1, 1, embed_dim_1, left_cache_len_1).
            .
            .
            .
            - The updated cached left context tensor for the second convolution module of each
              Zipformer2EncoderLayer within the sixth Zipformer2Encoder.
              The tensor is of shape (num_layers_6, 1, embed_dim_6, left_cache_len_6).
            - The updated total processed length tensor after subsampling of shape (1,).
        """
        # pylint: disable=too-many-arguments,too-many-locals

        x, new_left_cached_subsample_frames = self.subsampling(x, left_cached_subsample_frames)

        batch_size, seq_len, _ = x.size()
        src_key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=x.device)

        # processed_mask is used to mask out the initial states. For example, if
        # left_context_frames == 6, the tensor will contain [5, 4, 3, 2, 1, 0],
        # i.e. a reversed torch.arange(left_context_frames).
        processed_mask = torch.arange(
            self.left_context_frames - 1, -1, -1, dtype=torch.int32, device=x.device,
        ).expand(batch_size, self.left_context_frames)

        # (1, left_context_size) i.e. (batch_size, left_context_size)
        processed_mask = processed_mask >= processed_len.expand(processed_mask.size())

        # Update processed lengths
        new_processed_len = processed_len + seq_len

        # (1, left_context_size + chunk_size)
        src_key_padding_mask = torch.cat((processed_mask, src_key_padding_mask), dim=1)
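
        # Worked example of the masking above (illustrative numbers, not the real
        # configuration): with left_context_frames == 4 and processed_len == 2, the reversed
        # arange is [3, 2, 1, 0] and processed_mask becomes [True, True, False, False], so the
        # two not-yet-populated left-context positions are ignored by attention while the
        # current chunk itself stays unmasked.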

        # If the last encoder 'x' has the largest dimension, then the 'output' will be just this
        # last 'x' unchanged. Otherwise it will be concatenated from different pieces of 'x',
        # taking each output channel dimension from the most recent x that has it present.
        output = torch.empty(
            batch_size, seq_len, self.projection_dim, dtype=torch.float32, device=x.device,
        )
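
        # For example (illustrative dimensions, not the actual configuration): with
        # encoder_dims == (192, 256, 384, 512, 384, 256), projection_dim is 512. Channels
        # [0:256) of 'output' then come from encoder 6 (the most recent x that has them),
        # channels [256:384) from encoder 5, and channels [384:512) from encoder 4, because
        # each later, narrower encoder overwrites only the channels it actually produces.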

        # The number of Zipformer2Encoder stacks is fixed and equal to 6 for any Zipformer2 size,
        # including small, medium and large. For the sake of a smoother TorchScript model export
        # we use an explicit sequential forward call of each Zipformer2Encoder module instead of
        # using torch.nn.ModuleList.

        # Encoder 1

        (
            x,
            new_left_cached_keys_encoder_1,
            new_left_cached_nonlin_attentions_encoder_1,
            new_left_cached_values_1_encoder_1,
            new_left_cached_values_2_encoder_1,
            new_left_cached_convolutions_1_encoder_1,
            new_left_cached_convolutions_2_encoder_1,
        ) = self.encoder_1(
            x,
            left_cached_keys_encoder_1,
            left_cached_nonlin_attentions_encoder_1,
            left_cached_values_1_encoder_1,
            left_cached_values_2_encoder_1,
            left_cached_convolutions_1_encoder_1,
            left_cached_convolutions_2_encoder_1,
            src_key_padding_mask[:, ::self.downsampling_factors[0]],
        )
        output[:, :, :x.size(2)] = x

        # Encoder 2

        pad = torch.zeros(
            x.size(0), x.size(1), self.encoder_dims[1] - x.size(2),
            dtype=torch.float32,
            device=x.device,
        )
        x = torch.cat((x, pad), dim=2)

        (
            x,
            new_left_cached_keys_encoder_2,
            new_left_cached_nonlin_attentions_encoder_2,
            new_left_cached_values_1_encoder_2,
            new_left_cached_values_2_encoder_2,
            new_left_cached_convolutions_1_encoder_2,
            new_left_cached_convolutions_2_encoder_2,
        ) = self.encoder_2(
            x,
            left_cached_keys_encoder_2,
            left_cached_nonlin_attentions_encoder_2,
            left_cached_values_1_encoder_2,
            left_cached_values_2_encoder_2,
            left_cached_convolutions_1_encoder_2,
            left_cached_convolutions_2_encoder_2,
            src_key_padding_mask[:, ::self.downsampling_factors[1]],
        )
        output[:, :, :x.size(2)] = x

        # Encoder 3

        pad = torch.zeros(
            x.size(0), x.size(1), self.encoder_dims[2] - x.size(2),
            dtype=torch.float32,
            device=x.device,
        )
        x = torch.cat((x, pad), dim=2)

        (
            x,
            new_left_cached_keys_encoder_3,
            new_left_cached_nonlin_attentions_encoder_3,
            new_left_cached_values_1_encoder_3,
            new_left_cached_values_2_encoder_3,
            new_left_cached_convolutions_1_encoder_3,
            new_left_cached_convolutions_2_encoder_3,
        ) = self.encoder_3(
            x,
            left_cached_keys_encoder_3,
            left_cached_nonlin_attentions_encoder_3,
            left_cached_values_1_encoder_3,
            left_cached_values_2_encoder_3,
            left_cached_convolutions_1_encoder_3,
            left_cached_convolutions_2_encoder_3,
            src_key_padding_mask[:, ::self.downsampling_factors[2]],
        )
        output[:, :, :x.size(2)] = x

        # Encoder 4

        pad = torch.zeros(
            x.size(0), x.size(1), self.encoder_dims[3] - x.size(2),
            dtype=torch.float32,
            device=x.device,
        )
        x = torch.cat((x, pad), dim=2)

        (
            x,
            new_left_cached_keys_encoder_4,
            new_left_cached_nonlin_attentions_encoder_4,
            new_left_cached_values_1_encoder_4,
            new_left_cached_values_2_encoder_4,
            new_left_cached_convolutions_1_encoder_4,
            new_left_cached_convolutions_2_encoder_4,
        ) = self.encoder_4(
            x,
            left_cached_keys_encoder_4,
            left_cached_nonlin_attentions_encoder_4,
            left_cached_values_1_encoder_4,
            left_cached_values_2_encoder_4,
            left_cached_convolutions_1_encoder_4,
            left_cached_convolutions_2_encoder_4,
            src_key_padding_mask[:, ::self.downsampling_factors[3]],
        )
        output[:, :, :x.size(2)] = x

        # Encoder 5

        x = x[:, :, :self.encoder_dims[4]]
        (
            x,
            new_left_cached_keys_encoder_5,
            new_left_cached_nonlin_attentions_encoder_5,
            new_left_cached_values_1_encoder_5,
            new_left_cached_values_2_encoder_5,
            new_left_cached_convolutions_1_encoder_5,
            new_left_cached_convolutions_2_encoder_5,
        ) = self.encoder_5(
            x,
            left_cached_keys_encoder_5,
            left_cached_nonlin_attentions_encoder_5,
            left_cached_values_1_encoder_5,
            left_cached_values_2_encoder_5,
            left_cached_convolutions_1_encoder_5,
            left_cached_convolutions_2_encoder_5,
            src_key_padding_mask[:, ::self.downsampling_factors[4]],
        )
        output[:, :, :x.size(2)] = x

        # Encoder 6

        x = x[:, :, :self.encoder_dims[5]]
        (
            x,
            new_left_cached_keys_encoder_6,
            new_left_cached_nonlin_attentions_encoder_6,
            new_left_cached_values_1_encoder_6,
            new_left_cached_values_2_encoder_6,
            new_left_cached_convolutions_1_encoder_6,
            new_left_cached_convolutions_2_encoder_6,
        ) = self.encoder_6(
            x,
            left_cached_keys_encoder_6,
            left_cached_nonlin_attentions_encoder_6,
            left_cached_values_1_encoder_6,
            left_cached_values_2_encoder_6,
            left_cached_convolutions_1_encoder_6,
            left_cached_convolutions_2_encoder_6,
            src_key_padding_mask[:, ::self.downsampling_factors[5]],
        )
        output[:, :, :x.size(2)] = x

        output = self.downsample_output(output)
        output = self.projection_output(output)
        if self.ctc:
            output = torch.nn.functional.log_softmax(output, dim=2)

        return (
            output,
            # Because of the reasons mentioned in previous comments,
            # for the sake of easier TorchScript and ONNX export we
            # preserve the explicit listing of each left cache tensor.
            new_left_cached_subsample_frames,

            new_left_cached_keys_encoder_1,
            new_left_cached_nonlin_attentions_encoder_1,
            new_left_cached_values_1_encoder_1,
            new_left_cached_values_2_encoder_1,
            new_left_cached_convolutions_1_encoder_1,
            new_left_cached_convolutions_2_encoder_1,

            new_left_cached_keys_encoder_2,
            new_left_cached_nonlin_attentions_encoder_2,
            new_left_cached_values_1_encoder_2,
            new_left_cached_values_2_encoder_2,
            new_left_cached_convolutions_1_encoder_2,
            new_left_cached_convolutions_2_encoder_2,

            new_left_cached_keys_encoder_3,
            new_left_cached_nonlin_attentions_encoder_3,
            new_left_cached_values_1_encoder_3,
            new_left_cached_values_2_encoder_3,
            new_left_cached_convolutions_1_encoder_3,
            new_left_cached_convolutions_2_encoder_3,

            new_left_cached_keys_encoder_4,
            new_left_cached_nonlin_attentions_encoder_4,
            new_left_cached_values_1_encoder_4,
            new_left_cached_values_2_encoder_4,
            new_left_cached_convolutions_1_encoder_4,
            new_left_cached_convolutions_2_encoder_4,

            new_left_cached_keys_encoder_5,
            new_left_cached_nonlin_attentions_encoder_5,
            new_left_cached_values_1_encoder_5,
            new_left_cached_values_2_encoder_5,
            new_left_cached_convolutions_1_encoder_5,
            new_left_cached_convolutions_2_encoder_5,

            new_left_cached_keys_encoder_6,
            new_left_cached_nonlin_attentions_encoder_6,
            new_left_cached_values_1_encoder_6,
            new_left_cached_values_2_encoder_6,
            new_left_cached_convolutions_1_encoder_6,
            new_left_cached_convolutions_2_encoder_6,

            new_processed_len,
        )


def get_init_states(
    input_dim: int,
    num_encoder_layers: list[int],
    downsample_left_pad_frames: list[int],
    encoder_dims: list[int],
    query_dims: list[int],
    value_dims: list[int],
    head_dims: list[int],
    convolution_left_pad_frames: list[int],
    device: torch.device,
) -> list[torch.Tensor]:
    """
    Get initial states for the Zipformer2 encoder. The method generates a list of torch tensors,
    where the first tensor corresponds to a subsampling module left cache. Next, for each
    Zipformer2Encoder module we add six cache tensors that are essential for multi-head attention
    and convolution modules. Finally, at the end we append a total processed frames tensor,
    initialized with zero.

    Parameters
    ----------
    input_dim : int
        The number of input features.
    num_encoder_layers : list[int]
        The number of Zipformer2EncoderLayer modules for each Zipformer2Encoder stack.
    downsample_left_pad_frames : list[int]
        The multi-head attention left context cache frames after downsampling.
    encoder_dims : list[int]
        The embedding dimension for each Zipformer2Encoder stack.
    query_dims : list[int]
        The multi-head attention query dimension for each Zipformer2Encoder stack.
    value_dims : list[int]
        The multi-head attention value dimension for each Zipformer2Encoder stack.
    head_dims : list[int]
        The non-linear attention head dimension for each Zipformer2Encoder stack.
    convolution_left_pad_frames : list[int]
        The convolution modules left padding number of frames for each Zipformer2Encoder stack.
    device : torch.device
        The device used to store cache tensors. Should be
        either torch.device("cpu") or torch.device("cuda").

    Returns
    -------
    list[torch.Tensor[torch.float32 | torch.int32]]
        A list of left cache tensors.
        - A subsampling module left cache tensor of shape (1, 10, input_dim).
        - The first Zipformer2Encoder cached attention key tensor of the left context in each
          Zipformer2EncoderLayer of the stack.
          The tensor is of shape (num_layers_1, 1, left_context_len_1, query_dim_1).
        - The first Zipformer2Encoder left context cached attention tensor for the non-linear
          attention module in each Zipformer2EncoderLayer of the stack.
          The tensor is of shape (num_layers_1, 1, left_context_len_1, head_dim_1).
        - The first Zipformer2Encoder cached left context tensor for the first self-attention
          module in each Zipformer2EncoderLayer of the stack.
          The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1).
        - The first Zipformer2Encoder cached left context tensor for the second self-attention
          module in each Zipformer2EncoderLayer of the stack.
          The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1).
        - The first Zipformer2Encoder cached left context tensor for the first convolution module
          in each Zipformer2EncoderLayer of the stack.
          The tensor is of shape (num_layers_1, 1, encoder_dim_1, conv_left_pad_1).
        - The first Zipformer2Encoder cached left context tensor for the second convolution module
          in each Zipformer2EncoderLayer of the stack.
          The tensor is of shape (num_layers_1, 1, encoder_dim_1, conv_left_pad_1).
        .
        .
        .
        - The sixth Zipformer2Encoder cached left context tensor for the second convolution module
          in each Zipformer2EncoderLayer of the stack.
          The tensor is of shape (num_layers_6, 1, encoder_dim_6, conv_left_pad_6).
        - The processed length integer tensor initialized with a single zero element.
          The tensor is of shape (1,).
    """
    # pylint: disable=too-many-locals

    if not (
        len(num_encoder_layers)
        == len(downsample_left_pad_frames)
        == len(encoder_dims)
        == len(query_dims)
        == len(value_dims)
        == len(head_dims)
        == len(convolution_left_pad_frames)
    ):
        raise ValueError(
            'It is required that all encoder parameter lists have the same '
            'length, but got following parameter list lengths:\n'
            f'len(num_encoder_layers) == {len(num_encoder_layers)}\n'
            f'len(downsample_left_pad_frames) == {len(downsample_left_pad_frames)}\n'
            f'len(encoder_dims) == {len(encoder_dims)}\n'
            f'len(query_dims) == {len(query_dims)}\n'
            f'len(value_dims) == {len(value_dims)}\n'
            f'len(head_dims) == {len(head_dims)}\n'
            f'len(convolution_left_pad_frames) == {len(convolution_left_pad_frames)}.',
        )

    states = [subsampling_get_init_states(input_dim, device)]
    for i, num_layers in enumerate(num_encoder_layers):

        encoder_dim = encoder_dims[i]
        query_dim = query_dims[i]
        value_dim = value_dims[i]
        head_dim = head_dims[i]
        left_context_len = downsample_left_pad_frames[i]
        left_cache_len = convolution_left_pad_frames[i]

        # batch size is 1
        states += [
            torch.zeros(
                num_layers, 1, left_context_len, query_dim, dtype=torch.float32, device=device,
            ),
            torch.zeros(
                num_layers, 1, left_context_len, head_dim, dtype=torch.float32, device=device,
            ),
            torch.zeros(
                num_layers, 1, left_context_len, value_dim, dtype=torch.float32, device=device,
            ),
            torch.zeros(
                num_layers, 1, left_context_len, value_dim, dtype=torch.float32, device=device,
            ),
            torch.zeros(
                num_layers, 1, encoder_dim, left_cache_len, dtype=torch.float32, device=device,
            ),
            torch.zeros(
                num_layers, 1, encoder_dim, left_cache_len, dtype=torch.float32, device=device,
            ),
        ]

    states.append(torch.zeros(1, dtype=torch.int32, device=device))

    return states
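
# A minimal streaming-usage sketch, kept as a comment so the module stays import-only. All of
# the concrete values below (and the name 'model', an instance of Zipformer2) are illustrative
# assumptions, not the configuration of any particular released model:
#
#     states = get_init_states(
#         input_dim=80,
#         num_encoder_layers=[2, 2, 3, 4, 3, 2],
#         downsample_left_pad_frames=[64, 32, 16, 8, 16, 32],
#         encoder_dims=[192, 256, 384, 512, 384, 256],
#         query_dims=[128, 128, 128, 128, 128, 128],
#         value_dims=[96, 96, 96, 96, 96, 96],
#         head_dims=[96, 96, 96, 96, 96, 96],
#         convolution_left_pad_frames=[15, 15, 15, 15, 15, 15],
#         device=torch.device("cpu"),
#     )
#     # features_chunk: (1, num_frames, 80); the forward call consumes the current states and
#     # returns the chunk output together with the updated states in the same order.
#     output, *states = model(features_chunk, *states)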


class Zipformer2EncoderLayer(torch.nn.Module):
    """
    Zipformer2EncoderLayer module, the basic block of the Zipformer2Encoder encoder stack.
    """
    # pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        embed_dim: int,
        pos_dim: int,
        num_heads: int,
        query_head_dim: int,
        pos_head_dim: int,
        value_head_dim: int,
        feedforward_dim: int,
        cnn_module_kernel: int,
        left_context_len: int,
        right_context_len: int,
        device: torch.device,
    ) -> None:
        """
        Zipformer2EncoderLayer initialization.

        Parameters
        ----------
        embed_dim : int
            The input and output embedding dimension. The number of channels is the same for input
            and output of this module.
        pos_dim : int
            The dimension of the relative positional embedding.
        num_heads : int
            The number of heads for attention weights and self-attention.
        query_head_dim : int
            The dimension of the query and key per attention head in attention weights.
        pos_head_dim : int
            The dimension of the projected positional encoding
            per attention head in attention weights.
        value_head_dim : int
            The dimension of the value per attention head in self-attention.
        feedforward_dim : int
            The hidden dimension of the feedforward modules.
        cnn_module_kernel : int
            The kernel size of the convolution modules.
        left_context_len : int
            The module left context number of subsampled frames.
        right_context_len : int
            The module right context number of subsampled frames.
            Used to update attention and convolution left caches.
        device : torch.device
            The device used to store the layer weights. Should be
            either torch.device("cpu") or torch.device("cuda").
        """
        # pylint: disable=too-many-arguments

        super().__init__()

        self.left_context_len = left_context_len

        # self.bypass implements the whole layer skipping.
        self.bypass = BypassModule(embed_dim, device)
        # bypass_mid is a bypass used in the middle of the layer.
        self.bypass_mid = BypassModule(embed_dim, device)

        self.self_attn_weights = RelPositionMultiheadAttentionWeights(
            embed_dim, pos_dim, num_heads, query_head_dim, pos_head_dim, right_context_len, device,
        )

        self.self_attn1 = SelfAttention(
            embed_dim, num_heads, value_head_dim, right_context_len, device,
        )
        self.self_attn2 = SelfAttention(
            embed_dim, num_heads, value_head_dim, right_context_len, device,
        )

        self.nonlin_attention = NonlinAttention(
            embed_dim, 3 * embed_dim // 4, right_context_len, device,
        )

        self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, device)
        self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, device)
        self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4, device)

        self.conv_module1 = ConvolutionModule(
            embed_dim, cnn_module_kernel, right_context_len, device,
        )
        self.conv_module2 = ConvolutionModule(
            embed_dim, cnn_module_kernel, right_context_len, device,
        )

        self.norm = BiasNorm(embed_dim, device)

    def forward(
        self,
        src: torch.Tensor,
        pos_emb: torch.Tensor,
        left_cached_key: torch.Tensor,
        left_cached_nonlin_attn: torch.Tensor,
        left_cached_val_1: torch.Tensor,
        left_cached_val_2: torch.Tensor,
        left_cached_conv_1: torch.Tensor,
        left_cached_conv_2: torch.Tensor,
        src_key_padding_mask: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """
        Does a forward pass of the Zipformer2EncoderLayer module. Returns an output tensor with
        the same shape as the input, and updated left caches for the multiple attention and
        convolution modules.

        Parameters
        ----------
        src : torch.Tensor[torch.float32]
            The input float tensor of shape (1, seq_len, embed_dim). The module input.
        pos_emb : torch.Tensor[torch.float32]
            A positional embedding tensor
            of shape (1, left_context_len + 2 * seq_len - 1, pos_dim).
        left_cached_key : torch.Tensor[torch.float32]
            A cached attention key tensor of the left context
            of shape (1, left_context_len, query_dim).
        left_cached_nonlin_attn : torch.Tensor[torch.float32]
            A left context cached attention tensor for the non-linear attention module
            of shape (1, left_context_len, head_dim).
        left_cached_val_1 : torch.Tensor[torch.float32]
            A cached left context tensor for the first self-attention module
            of shape (1, left_context_len, value_dim).
        left_cached_val_2 : torch.Tensor[torch.float32]
            A cached left context for the second self-attention module
            of shape (1, left_context_len, value_dim).
        left_cached_conv_1 : torch.Tensor[torch.float32]
            A cached left context tensor for the first convolution module
            of shape (1, embed_dim, left_cache_len).
        left_cached_conv_2 : torch.Tensor[torch.float32]
            A cached left context tensor for the second convolution module
            of shape (1, embed_dim, left_cache_len).
        src_key_padding_mask : torch.Tensor[torch.bool]
            A boolean tensor of shape (1, seq_len_2). Positions that are True in this mask will be
            ignored as sources in the attention weighting and convolution modules.

        Returns
        -------
        tuple[
            torch.Tensor[torch.float32],
            torch.Tensor[torch.float32],
            torch.Tensor[torch.float32],
            torch.Tensor[torch.float32],
            torch.Tensor[torch.float32],
            torch.Tensor[torch.float32],
            torch.Tensor[torch.float32],
        ]
            A tuple of seven float tensors:
            - The module output of shape (1, seq_len, embed_dim).
              A tensor with the same shape as the input.
            - The updated left context cached attention key tensor
              of shape (1, left_context_len, query_dim).
            - The updated left context cached attention tensor for the non-linear attention module
              of shape (1, left_context_len, head_dim).
            - The updated cached left context for the first self-attention module
              of shape (1, left_context_len, value_dim).
            - The updated cached left context for the second self-attention module
              of shape (1, left_context_len, value_dim).
            - The updated cached left context for the first convolution module
              of shape (1, embed_dim, left_cache_len).
            - The updated cached left context for the second convolution module
              of shape (1, embed_dim, left_cache_len).
        """

        src_orig = src

        # attn_weights: (1, num_heads, seq_len, seq_len_2)
        attn_weights, left_cached_key = self.self_attn_weights(
            src, pos_emb, left_cached_key, src_key_padding_mask,
        )
        src = src + self.feed_forward1(src)

        na, left_cached_nonlin_attn = self.nonlin_attention(
            src, attn_weights[:, 0], left_cached_nonlin_attn,
        )
        src = src + na

        self_attn, left_cached_val_1 = self.self_attn1(src, attn_weights, left_cached_val_1)
        src = src + self_attn

        src_conv, left_cached_conv_1 = self.conv_module1(
            src, left_cached_conv_1, src_key_padding_mask[:, self.left_context_len:],
        )
        src = src + src_conv

        src = src + self.feed_forward2(src)

        # bypass in the middle of the layer.
        src = self.bypass_mid(src_orig, src)

        self_attn, left_cached_val_2 = self.self_attn2(src, attn_weights, left_cached_val_2)
        src = src + self_attn

        src_conv, left_cached_conv_2 = self.conv_module2(
            src, left_cached_conv_2, src_key_padding_mask[:, self.left_context_len:],
        )
        src = src + src_conv

        src = src + self.feed_forward3(src)

        src = self.norm(src)
        src = self.bypass(src_orig, src)

        return (
            src,
            left_cached_key,
            left_cached_nonlin_attn,
            left_cached_val_1,
            left_cached_val_2,
            left_cached_conv_1,
            left_cached_conv_2,
        )


class Zipformer2Encoder(torch.nn.Module):
    """
    Zipformer2Encoder is a stack of Zipformer2EncoderLayer modules.
    """

    def __init__(
        self,
        encoder_layer: torch.nn.Module,
        num_layers: int,
        embed_dim: int,
        pos_dim: int,
        pos_max_len: int,
        downsample: int,
        device: torch.device,
    ) -> None:
        """
        Zipformer2Encoder initialization.

        Parameters
        ----------
        encoder_layer : torch.nn.Module
            An instance of the Zipformer2EncoderLayer class.
        num_layers : int
            The number of encoder Zipformer2EncoderLayer modules in the stack.
        embed_dim : int
            The input and output embedding dimension. The embedding dimension is the same for
            input and output of this module.
        pos_dim : int
            The dimension for the relative positional embedding.
        pos_max_len : int
            The maximum input length of the relative positional embeddings,
            passed to the CompactRelPositionalEncoding module.
        downsample : int
            The downsampling factor of the module, the input will be downsampled in the beginning
            and upsampled back at the end.
        device : torch.device
            The device used to store the layer weights. Should be
            either torch.device("cpu") or torch.device("cuda").
        """

        super().__init__()

        self.num_layers = num_layers
        self.downsample = SimpleDownsample(downsample, device)
        self.encoder_pos = CompactRelPositionalEncoding(
            pos_dim, pos_max_len, encoder_layer.left_context_len, device,
        )

        self.layers = torch.nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)],
        )
        self.upsample = SimpleUpsample(downsample)
        self.out_combiner = BypassModule(embed_dim, device)

    def forward(
        self,
        src: torch.Tensor,
        left_cached_keys: torch.Tensor,
        left_cached_nonlin_attentions: torch.Tensor,
        left_cached_values_1: torch.Tensor,
        left_cached_values_2: torch.Tensor,
        left_cached_convolutions_1: torch.Tensor,
        left_cached_convolutions_2: torch.Tensor,
        src_key_padding_mask: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """
        Does a forward pass of the Zipformer2Encoder module. Returns an output tensor with the
        same shape as the input, and updated left caches for the multiple attention and
        convolution modules.

        Parameters
        ----------
        src : torch.Tensor[torch.float32]
            The input float tensor of shape (1, seq_len, embed_dim). The module input.
        left_cached_keys : torch.Tensor[torch.float32]
            A cached attention key tensor of the left context for each Zipformer2EncoderLayer.
            The tensor is of shape (num_layers, 1, left_context_len, query_dim).
        left_cached_nonlin_attentions : torch.Tensor[torch.float32]
            A left context cached attention tensor for the non-linear attention module of each
            Zipformer2EncoderLayer. The tensor is
            of shape (num_layers, 1, left_context_len, head_dim).
        left_cached_values_1 : torch.Tensor[torch.float32]
            A cached left context tensor for the first self-attention module of each
            Zipformer2EncoderLayer. The tensor is
            of shape (num_layers, 1, left_context_len, value_dim).
        left_cached_values_2 : torch.Tensor[torch.float32]
            A cached left context tensor for the second self-attention module of each
            Zipformer2EncoderLayer. The tensor is
            of shape (num_layers, 1, left_context_len, value_dim).
        left_cached_convolutions_1 : torch.Tensor[torch.float32]
            A cached left context tensor for the first convolution module of each
            Zipformer2EncoderLayer. The tensor is
            of shape (num_layers, 1, embed_dim, left_cache_len).
        left_cached_convolutions_2 : torch.Tensor[torch.float32]
            A cached left context tensor for the second convolution module of each
            Zipformer2EncoderLayer. The tensor is
            of shape (num_layers, 1, embed_dim, left_cache_len).
        src_key_padding_mask : torch.Tensor[torch.bool]
            A boolean tensor of shape (1, seq_len_2). Positions that are True in this mask will be
            ignored as sources in the attention weighting and convolution modules.

        Returns
        -------
        tuple[
            torch.Tensor[torch.float32],
            torch.Tensor[torch.float32],
            torch.Tensor[torch.float32],
            torch.Tensor[torch.float32],
            torch.Tensor[torch.float32],
            torch.Tensor[torch.float32],
            torch.Tensor[torch.float32],
        ]
            A tuple of seven float tensors:
            - The module output of shape (1, seq_len, embed_dim).
              A tensor with the same shape as the input.
            - The updated cached attention key tensor of the left context for each
              Zipformer2EncoderLayer. The tensor is
              of shape (num_layers, 1, left_context_len, query_dim).
            - The updated left context cached attention tensor for the non-linear attention module
              of each Zipformer2EncoderLayer. The tensor is
              of shape (num_layers, 1, left_context_len, head_dim).
            - The updated cached left context tensor for the first self-attention module of each
              Zipformer2EncoderLayer. The tensor is
              of shape (num_layers, 1, left_context_len, value_dim).
            - The updated cached left context tensor for the second self-attention module of each
              Zipformer2EncoderLayer. The tensor is
              of shape (num_layers, 1, left_context_len, value_dim).
            - The updated cached left context tensor for the first convolution module of each
              Zipformer2EncoderLayer. The tensor is
              of shape (num_layers, 1, embed_dim, left_cache_len).
            - The updated cached left context tensor for the second convolution module of each
              Zipformer2EncoderLayer. The tensor is
              of shape (num_layers, 1, embed_dim, left_cache_len).
        """
        # pylint: disable=too-many-locals

        src_orig = src
        src = self.downsample(src)
        pos_emb = self.encoder_pos(src)

        new_left_cached_keys = torch.empty(
            left_cached_keys.shape, dtype=torch.float32, device=left_cached_keys.device,
        )
        new_left_cached_nonlin_attentions = torch.empty(
            left_cached_nonlin_attentions.shape,
            dtype=torch.float32,
            device=left_cached_nonlin_attentions.device,
        )
        new_left_cached_values_1 = torch.empty(
            left_cached_values_1.shape, dtype=torch.float32, device=left_cached_values_1.device,
        )
        new_left_cached_values_2 = torch.empty(
            left_cached_values_2.shape, dtype=torch.float32, device=left_cached_values_2.device,
        )
        new_left_cached_convolutions_1 = torch.empty(
            left_cached_convolutions_1.shape,
            dtype=torch.float32,
            device=left_cached_convolutions_1.device,
        )
        new_left_cached_convolutions_2 = torch.empty(
            left_cached_convolutions_2.shape,
            dtype=torch.float32,
            device=left_cached_convolutions_2.device,
        )

        for i, mod in enumerate(self.layers):
            (
                src,
                new_left_cached_keys[i],
                new_left_cached_nonlin_attentions[i],
                new_left_cached_values_1[i],
                new_left_cached_values_2[i],
                new_left_cached_convolutions_1[i],
                new_left_cached_convolutions_2[i],
            ) = mod(
                src,
                pos_emb,
                left_cached_keys[i],
                left_cached_nonlin_attentions[i],
                left_cached_values_1[i],
                left_cached_values_2[i],
                left_cached_convolutions_1[i],
                left_cached_convolutions_2[i],
                src_key_padding_mask,
            )

        src = self.upsample(src)

        # Remove any extra frames that are not a multiple of downsample_factor
        src = src[:, : src_orig.size(1)]
        src = self.out_combiner(src_orig, src)

        return (
            src,
            new_left_cached_keys,
            new_left_cached_nonlin_attentions,
            new_left_cached_values_1,
            new_left_cached_values_2,
            new_left_cached_convolutions_1,
            new_left_cached_convolutions_2,
        )


class BypassModule(torch.nn.Module):
    """
    A bypass module that implements a learnable bypass scale for each input channel.
    """

    def __init__(self, num_channels: int, device: torch.device) -> None:
        """
        BypassModule initialization.

        Parameters
        ----------
        num_channels : int
            The number of input channels, corresponds to the number of learnable bypass scales.
        device : torch.device
            The device used to store the layer weights. Should be
            either torch.device("cpu") or torch.device("cuda").
        """

        super().__init__()
        self.bypass_scale = torch.nn.Parameter(
            torch.ones(num_channels, dtype=torch.float32, device=device),
        )

    def forward(self, x_early: torch.Tensor, x_later: torch.Tensor) -> torch.Tensor:
        """
        Does a forward pass of the BypassModule module.

        Parameters
        ----------
        x_early : torch.Tensor[torch.float32]
            The input float tensor of shape (1, seq_len, num_channels).
            The module input that will be propagated with (1 - self.bypass_scale) weight.
        x_later : torch.Tensor[torch.float32]
            An input float tensor of shape (1, seq_len, num_channels).
            The module input that will be propagated with self.bypass_scale weight.

        Returns
        -------
        torch.Tensor[torch.float32]
            A float tensor of shape (1, seq_len, num_channels). The shape is the same for x_early
            and x_later. The output is a per-channel interpolation between x_early and x_later,
            weighted by self.bypass_scale.
        """

        # It's just a slightly more efficient implementation of
        # (1.0 - self.bypass_scale) * x_early + self.bypass_scale * x_later
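        # For example, a bypass_scale entry of 0.0 keeps x_early unchanged for that channel
        # (the layer is fully skipped), 1.0 keeps only x_later (no bypass), and intermediate
        # values blend the two.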
        return x_early + (x_later - x_early) * self.bypass_scale


class SimpleDownsample(torch.nn.Module):
    """
    A downsample layer, does downsampling by weighted sum aggregation.
    """

    def __init__(self, downsample: int, device: torch.device) -> None:
        """
        SimpleDownsample initialization.

        Parameters
        ----------
        downsample : int
            The module downsampling factor.
        device : torch.device
            The device used to store the layer weights.
            Either torch.device("cpu") or torch.device("cuda").
        """

        super().__init__()
        self.weights = torch.nn.Parameter(
            torch.zeros(downsample, 1, dtype=torch.float32, device=device),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Does a forward pass of the SimpleDownsample module.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            The input float tensor of shape (1, seq_len, num_channels).
            The module input that will be downsampled.

        Returns
        -------
        torch.Tensor[torch.float32]
            A float tensor of shape
            (1, (seq_len + downsample - 1) // downsample, num_channels).
            The downsampled output of the module.
        """

        downsample = self.weights.size(0)
        if downsample == 1:
            return x

        batch_size, seq_len, in_channels = x.size()  # batch_size is 1
        downsampled_seq_len = (seq_len + downsample - 1) // downsample

        # Pad to an exact multiple of downsample. Right-pad x, repeating the last element.
        pad_frames = downsampled_seq_len * downsample - seq_len
        if pad_frames > 0:
            pad = x[:, seq_len - 1:, :].expand(batch_size, pad_frames, in_channels)
            x = torch.cat((x, pad), dim=1)

        # (1, seq_len, in_channels) -> (1, seq_len // downsample, downsample, in_channels)
        x = x.reshape(batch_size, downsampled_seq_len, downsample, in_channels)

        x = torch.sum(x * self.weights, dim=2)
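
        # Shape walk-through with illustrative numbers: for seq_len == 7 and downsample == 2,
        # downsampled_seq_len == 4, one frame of right-padding is added (repeating frame 6),
        # the reshape gives (1, 4, 2, in_channels), and the weighted sum over dim=2 produces
        # (1, 4, in_channels).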

        return x


class SimpleUpsample(torch.nn.Module):
    """
    An upsample layer, does upsampling by repeating the input frames.
    """

    def __init__(self, upsample: int) -> None:
        """
        SimpleUpsample initialization.

        Parameters
        ----------
        upsample : int
            The module upsampling factor.
        """

        super().__init__()
        self.upsample = upsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Does a forward pass of the SimpleUpsample module.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            The input float tensor of shape (1, seq_len, num_channels).
            The module input that will be upsampled.

        Returns
        -------
        torch.Tensor[torch.float32]
            A float tensor of shape (1, seq_len * upsample, num_channels).
            The upsampled output of the module.
        """

        if self.upsample == 1:
            return x

        x = torch.repeat_interleave(x, self.upsample, dim=1)

        return x


class CompactRelPositionalEncoding(torch.nn.Module):
    """
    Relative positional encoding module. This version is "compact" meaning it is able to encode
    the important information about the relative positions in a relatively small number of
    dimensions. The goal is to make it so that small differences between large relative offsets
    (e.g. 1000 vs. 1001) make very little difference to the embedding. Such differences were
    potentially important when encoding absolute position, but not important when encoding
    relative position because there is now no need to compare two large offsets with each other.

    This implementation works by projecting the interval [-infinity, infinity] to a finite
    interval using the torch.atan() function before doing the Fourier transform of that fixed
    interval. The torch.atan() function would compress the "long tails" too small, making it hard
    to distinguish between different magnitudes of large offsets. To mitigate this a logarithmic
    function is used to compress large offsets to a smaller range before applying torch.atan().
    Scalings are chosen in such a way that the embedding can clearly distinguish individual
    offsets as long as they are quite close to the origin, e.g. abs(offset) <= sqrt(embedding_dim).
    """

    def __init__(
        self, embed_dim: int, max_length: int, left_context_len: int, device: torch.device,
    ) -> None:
        """
        CompactRelPositionalEncoding initialization.

        Parameters
        ----------
        embed_dim : int
            The positional embedding dimension.
        max_length : int
            The maximum length of the input that this module will be able to handle after
            initialization without positional embeddings expansion. In case of a longer input the
            positional embeddings will be re-computed to fit the bigger length.
        left_context_len : int
            Length of cached left context.
        device : torch.device
            The device used to store the layer positional embeddings.
            Should be either torch.device("cpu") or torch.device("cuda").
        """

        super().__init__()

        if embed_dim % 2 != 0:
            raise ValueError(
                'Embedding dimension for CompactRelPositionalEncoding '
                f'should be an even number, but got {embed_dim}.',
            )

        self.embed_dim = embed_dim
        self.left_context_len = left_context_len
        self.pos_emb = self.create_pos_emb(max_length, device)

    def create_pos_emb(self, max_length: int, device: torch.device) -> torch.Tensor:
        """
        Creates relative positional embeddings based on the maximum length.
        This method is used to create positional embeddings with a
        sufficiently long temporal axis during module initialization.
        We want it to be big enough to avoid getting an input x that is longer
        than self.pos_emb during inference. On the other hand, we want
        to initialize it with the smallest maximum length possible to consume
        less memory.

        Parameters
        ----------
        max_length : int
            The maximum length of the input that can be handled by this layer. Increasing this
            allows processing of a longer input (speaking of the temporal dimension), but also
            increases the memory consumption.
        device : torch.device
            The device used to store the positional embeddings.
            Should be either torch.device("cpu") or torch.device("cuda").

        Returns
        -------
        torch.Tensor[torch.float32]
            A float tensor of shape (2 * max_length - 1, embed_dim).
            Relative positional embeddings.
        """

        # if max_length == 4, then x would contain [-3, -2, -1, 0, 1, 2, 3]
        x = torch.arange(-max_length + 1, max_length, dtype=torch.float32, device=device)

        # Compression length is an arbitrary heuristic, if it is larger we have more resolution
        # for small time offsets but less resolution for large time offsets.
        compression_length = self.embed_dim**0.5

        # Compressing x within the next line of code: similarly to the uncompressed x, it goes
        # from -infinity to infinity as the sequence length goes from -infinity to infinity, but
        # it does so more slowly than the sequence length for large absolute values of the
        # sequence length. The formula is chosen so that d(x_compressed) / dx is equal to 1
        # around x == 0, which is important.
        x = compression_length * torch.sign(x) * torch.log(torch.abs(x) / compression_length + 1.0)

        # results between -pi/2 and pi/2
        x = torch.atan(2.0 * torch.pi * x / self.embed_dim)

        freqs = torch.arange(1, self.embed_dim // 2 + 1, dtype=torch.float32, device=device)
        x = x.unsqueeze(1) * freqs

        pos_emb = torch.zeros(x.size(0), self.embed_dim, dtype=torch.float32, device=device)
        pos_emb[:, 0::2] = torch.cos(x)
        pos_emb[:, 1::2] = torch.sin(x)
        pos_emb[:, self.embed_dim - 1] = 1.0  # for bias.

        return pos_emb

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Does a forward pass of the CompactRelPositionalEncoding module.
        Returns relative positional embeddings based on the temporal dimension of the input x.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            An input float tensor of shape (1, seq_len, embed_dim).
            The module input. Its shape will be used to construct the relative positional
            embeddings.

        Returns
        -------
        torch.Tensor[torch.float32]
            A float tensor of shape (1, self.left_context_len + 2 * seq_len - 1, embed_dim).
            Relative positional embeddings.
        """

        if self.pos_emb.size(0) < 2 * (x.size(1) + self.left_context_len) - 1:
            self.pos_emb = self.create_pos_emb(x.size(1) + self.left_context_len, x.device)

        # Length of negative side: x.size(1) + self.left_context_len.
        # Length of positive side: x.size(1).
        pos_emb = self.pos_emb[
            self.pos_emb.size(0) // 2 - x.size(1) - self.left_context_len + 1:
            self.pos_emb.size(0) // 2 + x.size(1)
        ].unsqueeze(0).repeat(x.size(0), 1, 1)
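
        # Length check of the slice above: with seq_len == x.size(1) and left_context_len ==
        # self.left_context_len, the slice spans (seq_len + left_context_len - 1) positions on
        # the negative side, the center position, and (seq_len - 1) positions on the positive
        # side, i.e. left_context_len + 2 * seq_len - 1 rows in total, matching the documented
        # output shape.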

        # (1, left_context_len + 2 * seq_len - 1, embed_dim),
        # i.e. (batch_size, pos_len, embed_dim).
        return pos_emb
|
|
|
|
|
|
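
# A minimal, self-contained sketch (illustration only; not used by the model) of the slicing
# arithmetic in CompactRelPositionalEncoding.forward above. For a cached table of
# 2 * max_length - 1 rows, the selected window has left_context_len + 2 * seq_len - 1 rows,
# with the zero-offset row landing at index seq_len + left_context_len - 1. The toy sizes
# below are arbitrary, and the minimal table size is used for simplicity.
def _sketch_rel_pos_window(seq_len: int = 5, left_context_len: int = 3) -> tuple[int, int]:
    full_len = 2 * (seq_len + left_context_len) - 1  # rows in the cached table
    start = full_len // 2 - seq_len - left_context_len + 1
    end = full_len // 2 + seq_len
    window_len = end - start  # == left_context_len + 2 * seq_len - 1
    zero_offset_index = full_len // 2 - start  # == seq_len + left_context_len - 1
    return window_len, zero_offset_index
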
class RelPositionMultiheadAttentionWeights(torch.nn.Module):
    """
    Module that computes multi-head attention weights with relative position encoding.
    Various other modules consume the resulting attention weights: see, for example,
    the SelfAttention module which allows you to compute conventional self-attention.

    This is quite heavily modified from:
    "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
    """

    def __init__(
        self,
        embed_dim: int,
        pos_dim: int,
        num_heads: int,
        query_head_dim: int,
        pos_head_dim: int,
        right_context: int,
        device: torch.device,
    ) -> None:
        """
        RelPositionMultiheadAttentionWeights initialization.

        Parameters
        ----------
        embed_dim : int
            The embedding dimension. The number of channels at the input to this module.
        pos_dim : int
            The dimension of the positional embeddings.
        num_heads : int
            The number of attention heads for which to compute weights.
        query_head_dim : int
            The dimension of the query and key per head.
        pos_head_dim : int
            The dimension of the projected positional encoding per head.
        right_context : int
            The module's look-ahead (future) context, used to update the cached
            attention key of the left context correctly.
        device : torch.device
            The device used to store the layer weights. Should be
            either torch.device("cpu") or torch.device("cuda").
        """

        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.query_head_dim = query_head_dim
        self.pos_head_dim = pos_head_dim
        self.right_context = right_context

        in_proj_dim = (2 * query_head_dim + pos_head_dim) * num_heads
        self.in_proj = torch.nn.Linear(embed_dim, in_proj_dim, device=device)

        # Linear transformation for positional encoding.
        self.linear_pos = torch.nn.Linear(
            pos_dim, num_heads * pos_head_dim, bias=False, device=device,
        )

    def forward(
        self, x: torch.Tensor,
        pos_emb: torch.Tensor,
        left_cached_key: torch.Tensor,
        key_padding_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Does a forward pass of the RelPositionMultiheadAttentionWeights module.
        Returns attention weights and the updated cached attention key tensor
        of the left context.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            The input float tensor of shape (1, seq_len, embed_dim). The module input.
        pos_emb : torch.Tensor[torch.float32]
            A positional embedding tensor
            of shape (1, left_context_len + 2 * seq_len - 1, pos_dim).
        left_cached_key : torch.Tensor[torch.float32]
            A cached attention key tensor of the left context
            of shape (1, left_context_len, key_dim).
        key_padding_mask : torch.Tensor[torch.bool]
            A boolean tensor of shape (1, seq_len_2). Positions that are True in this mask
            will be ignored as sources in the attention weighting.

        Returns
        -------
        tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]]
            A tuple of two float tensors:
            - attention weights, of shape (1, num_heads, seq_len, seq_len_2),
              interpreted as (1, num_heads, tgt_seq_len, src_seq_len).
            - updated cached attention key tensor of the left context
              of shape (1, left_context_len, key_dim).
        """
        # pylint: disable=too-many-locals

        batch_size = x.size(0)  # batch size is 1
        seq_len = x.size(1)
        x = self.in_proj(x)

        query_dim = self.query_head_dim * self.num_heads

        # Self-attention.
        q = x[:, :, :query_dim]
        k = x[:, :, query_dim: 2 * query_dim]
        # p is the position-encoding query.
        p = x[:, :, 2 * query_dim:]

        # Pad key with cached left context.
        k = torch.cat((left_cached_key, k), dim=1)
        # Update cached left context.
        seq_len_2 = k.size(1)  # left_context_len + seq_len
        left_cached_key = k[
            :,
            seq_len_2 - self.right_context - left_cached_key.size(1):
            seq_len_2 - self.right_context,
        ]

        q = q.reshape(batch_size, seq_len, self.num_heads, self.query_head_dim)
        p = p.reshape(batch_size, seq_len, self.num_heads, self.pos_head_dim)
        k = k.reshape(batch_size, seq_len_2, self.num_heads, self.query_head_dim)

        # seq_len refers to target, seq_len_2 refers to source.
        q = q.permute(0, 2, 1, 3)  # (1, num_heads, seq_len, query_head_dim)
        p = p.permute(0, 2, 1, 3)  # (1, num_heads, seq_len, pos_head_dim)
        k = k.permute(0, 2, 3, 1)  # (1, num_heads, key_head_dim, seq_len_2)

        attn_scores = torch.matmul(q, k)  # (1, num_heads, seq_len, seq_len_2)

        pos_len = pos_emb.size(1)  # left_context_len + 2 * seq_len - 1
        # (1, pos_len, num_heads * pos_head_dim)
        pos_emb = self.linear_pos(pos_emb)
        pos_emb = pos_emb.reshape(
            batch_size, pos_len, self.num_heads, self.pos_head_dim,
        ).permute(0, 2, 3, 1)  # (1, num_heads, pos_head_dim, pos_len)

        # (1, num_heads, seq_len, pos_head_dim) x (1, num_heads, pos_head_dim, pos_len) ->
        # -> (1, num_heads, seq_len, pos_len) where pos_len represents relative position.
        pos_scores = torch.matmul(p, pos_emb)

        # Now we need to perform the relative shift of pos_scores. To do that we add
        # a column of zeros on the left side of the last dimension and then shift
        # with a pair of reshapes.
        pos_scores_pad = torch.zeros(
            pos_scores.size(0), pos_scores.size(1), pos_scores.size(2), 1,
            dtype=torch.float32,
            device=pos_scores.device,
        )
        # (1, num_heads, seq_len, pos_len + 1)
        pos_scores = torch.cat((pos_scores_pad, pos_scores), dim=3)
        pos_scores = pos_scores.reshape(
            batch_size, self.num_heads, pos_len + 1, seq_len,
        )  # (1, num_heads, pos_len + 1, seq_len)
        # Now drop the extra row that was added by the padding and reshape back.
        pos_scores = pos_scores[:, :, 1:].reshape(
            batch_size, self.num_heads, seq_len, pos_len,
        )  # (1, num_heads, seq_len, pos_len)

        # (1, num_heads, seq_len, seq_len_2)
        attn_scores = attn_scores + pos_scores[:, :, :, : attn_scores.size(3)]

        # (1, seq_len_2) -> (1, 1, 1, seq_len_2) to make it broadcastable to the
        # attn_scores shape.
        key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)

        attn_scores = torch.masked_fill(attn_scores, key_padding_mask, -1000.0)
        attn_weights = torch.softmax(attn_scores, dim=3)

        return attn_weights, left_cached_key

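
# A minimal, self-contained sketch (illustration only; not used by the model) of the
# relative-shift trick used in RelPositionMultiheadAttentionWeights.forward above: padding
# the position scores with one zero column and doing a reshape / slice / reshape re-aligns
# the relative-position axis for every query position. The toy sizes below are arbitrary.
def _sketch_relative_shift(
    num_heads: int = 2, seq_len: int = 3, pos_len: int = 7,
) -> torch.Tensor:
    pos_scores = torch.randn(1, num_heads, seq_len, pos_len)
    pad = torch.zeros(1, num_heads, seq_len, 1)
    shifted = torch.cat((pad, pos_scores), dim=3)  # (1, num_heads, seq_len, pos_len + 1)
    shifted = shifted.reshape(1, num_heads, pos_len + 1, seq_len)
    shifted = shifted[:, :, 1:].reshape(1, num_heads, seq_len, pos_len)
    return shifted  # same shape as pos_scores, with rows shifted relative to each other
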
class SelfAttention(torch.nn.Module):
    """
    The simplest possible attention module. This one works with pre-computed attention weights,
    e.g. as computed by RelPositionMultiheadAttentionWeights.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        value_head_dim: int,
        right_context: int,
        device: torch.device,
    ) -> None:
        """
        SelfAttention initialization.

        Parameters
        ----------
        embed_dim : int
            The input and output embedding dimension. The number of channels is the same for
            the input and output of this module.
        num_heads : int
            The number of attention heads.
        value_head_dim : int
            The dimension of the value per head.
        right_context : int
            The module's look-ahead (future) context, used to update the cached
            attention value of the left context correctly.
        device : torch.device
            The device used to store the layer weights.
            Either torch.device("cpu") or torch.device("cuda").
        """

        super().__init__()

        self.in_proj = torch.nn.Linear(embed_dim, num_heads * value_head_dim, device=device)
        self.out_proj = torch.nn.Linear(num_heads * value_head_dim, embed_dim, device=device)
        self.right_context = right_context

    def forward(
        self, x: torch.Tensor, attn_weights: torch.Tensor, left_cached_val: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Does a forward pass of the SelfAttention module. Returns the attention-weighted input
        tensor and the updated cached attention value tensor of the left context.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            The input float tensor of shape (1, seq_len, embed_dim). The module input.
        attn_weights : torch.Tensor[torch.float32]
            A tensor of shape (1, num_heads, seq_len, seq_len_2), with (seq_len, seq_len_2)
            being interpreted as (tgt_seq_len, src_seq_len).
            Expects attn_weights.sum(dim=3) == 1.0.
        left_cached_val : torch.Tensor[torch.float32]
            The cached attention value tensor of the left context
            of shape (1, left_context_len, value_dim).

        Returns
        -------
        tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]]
            A tuple of two float tensors:
            - attention-weighted output of shape (1, seq_len, embed_dim).
              A tensor with the same shape as the input x.
            - updated cached attention value tensor of the left context
              of shape (1, left_context_len, value_dim).
        """

        batch_size = x.size(0)  # batch size is 1
        num_heads = attn_weights.size(1)

        x = self.in_proj(x)  # (1, seq_len, num_heads * value_head_dim)

        x = torch.cat((left_cached_val, x), dim=1)
        # Update cached left context.
        left_cached_val = x[
            :,
            x.size(1) - self.right_context - left_cached_val.size(1):
            x.size(1) - self.right_context,
        ]

        x = x.reshape(batch_size, x.size(1), num_heads, x.size(2) // num_heads).permute(0, 2, 1, 3)

        # (1, num_heads, seq_len, seq_len_2) x (1, num_heads, seq_len_2, value_head_dim) ->
        # -> (1, num_heads, seq_len, value_head_dim)
        x = torch.matmul(attn_weights, x)

        # (1, num_heads, seq_len, value_head_dim) -> (1, seq_len, num_heads, value_head_dim)
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(batch_size, x.size(1), num_heads * x.size(3))

        # The returned value is of shape (1, seq_len, embed_dim), like the input.
        x = self.out_proj(x)

        return x, left_cached_val

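
# A minimal, self-contained sketch (illustration only; not used by the model) of the left-cache
# update shared by the attention modules above: after concatenating the cached left context
# with the current chunk, the new cache is the slice ending right_context frames before the
# end of the concatenation, so the cache length stays constant. The toy sizes are arbitrary.
def _sketch_left_cache_update(
    cache_len: int = 4, seq_len: int = 6, right_context: int = 2, dim: int = 8,
) -> torch.Tensor:
    cached = torch.zeros(1, cache_len, dim)
    current = torch.randn(1, seq_len, dim)
    x = torch.cat((cached, current), dim=1)  # (1, cache_len + seq_len, dim)
    total = x.size(1)
    new_cache = x[:, total - right_context - cache_len: total - right_context]
    assert new_cache.size(1) == cache_len
    return new_cache
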
class FeedforwardModule(torch.nn.Module):
    """
    Feedforward module in the Zipformer2 encoder.
    """

    def __init__(self, embed_dim: int, feedforward_dim: int, device: torch.device) -> None:
        """
        FeedforwardModule initialization.

        Parameters
        ----------
        embed_dim : int
            The input and output embedding dimension. The number of channels is the same for
            the input and output of this module.
        feedforward_dim : int
            The module hidden dimension.
        device : torch.device
            The device used to store the layer weights. Should be
            either torch.device("cpu") or torch.device("cuda").
        """

        super().__init__()

        self.in_proj = torch.nn.Linear(embed_dim, feedforward_dim, device=device)
        self.activation = SwooshL()
        self.out_proj = torch.nn.Linear(feedforward_dim, embed_dim, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Does a forward pass of the FeedforwardModule module.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            A float tensor of shape (1, seq_len, embed_dim). The module input.

        Returns
        -------
        torch.Tensor[torch.float32]
            A float tensor of shape (1, seq_len, embed_dim).
            The module output has the same shape as the input.
        """

        x = self.in_proj(x)
        x = self.activation(x)
        x = self.out_proj(x)

        return x

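
# A minimal, self-contained sketch (illustration only; not used by the model) of the SwooshL
# activation applied in FeedforwardModule above, assuming the definition from the Zipformer
# paper: SwooshL(x) = log(1 + exp(x - 4)) - 0.08 * x - 0.035. The constants here are quoted
# from that paper for reference only; the model uses the SwooshL class itself.
def _sketch_swooshl(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.softplus(x - 4.0) - 0.08 * x - 0.035
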
class NonlinAttention(torch.nn.Module):
    """
    This is like the ConvolutionModule, but refactored so that we use multiplication by
    attention weights (borrowed from the RelPositionMultiheadAttentionWeights module) instead
    of actual convolution. We also took out the second nonlinearity, the one after the
    attention mechanism.
    """

    def __init__(
        self, embed_dim: int, att_dim: int, right_context: int, device: torch.device,
    ) -> None:
        """
        NonlinAttention initialization.

        Parameters
        ----------
        embed_dim : int
            The input and output embedding dimension. The number of channels is the same for
            the input and output of this module.
        att_dim : int
            The attention output dimension of this module.
        right_context : int
            The module's look-ahead (future) context, used to update the left cache correctly.
        device : torch.device
            The device used to store the layer weights.
            Should be either torch.device("cpu") or torch.device("cuda").
        """

        super().__init__()

        self.in_proj = torch.nn.Linear(embed_dim, att_dim * 3, device=device)
        self.out_proj = torch.nn.Linear(att_dim, embed_dim, device=device)
        self.right_context = right_context

    def forward(
        self, x: torch.Tensor, attn_weights: torch.Tensor, left_cached_x: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Does a forward pass of the NonlinAttention module. Returns the attention-weighted input
        tensor and the updated attention input tensor cache of the left context.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            An input float tensor of shape (1, seq_len, embed_dim).
        attn_weights : torch.Tensor[torch.float32]
            A tensor of shape (1, seq_len, seq_len_2) that corresponds to a single attention
            head, with (seq_len, seq_len_2) being interpreted as (tgt_seq_len, src_seq_len).
            Expects attn_weights.sum(dim=2) == 1.0.
            Note: the first dimension here corresponds to the batch size.
        left_cached_x : torch.Tensor[torch.float32]
            A cached attention tensor of the left context
            of shape (1, left_context_len, att_dim).

        Returns
        -------
        tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]]
            A tuple of two float tensors:
            - attention-weighted output of shape (1, seq_len, embed_dim).
              A tensor with the same shape as the input x.
            - updated cached attention tensor of the left context
              of shape (1, left_context_len, att_dim).
        """

        x = self.in_proj(x)

        s, x, y = x.chunk(3, dim=2)

        x = x * torch.tanh(s)

        x = torch.cat((left_cached_x, x), dim=1)
        # Update the cached tensor.
        left_cached_x = x[
            :,
            x.size(1) - self.right_context - left_cached_x.size(1):
            x.size(1) - self.right_context,
        ]

        # (1, seq_len, seq_len_2) x (1, seq_len_2, att_dim) -> (1, seq_len, att_dim)
        x = torch.matmul(attn_weights, x)
        x = x * y

        x = self.out_proj(x)

        return x, left_cached_x

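
# A minimal, self-contained sketch (illustration only; not used by the model) of the gating in
# NonlinAttention.forward above: the input projection is split into three equal parts
# (s, x, y); x is gated by tanh(s) before the attention-weighted sum and by y after it.
# The toy sizes and the random attention weights below are arbitrary.
def _sketch_nonlin_attention_gating(seq_len: int = 4, att_dim: int = 6) -> torch.Tensor:
    proj = torch.randn(1, seq_len, 3 * att_dim)
    s, x, y = proj.chunk(3, dim=2)
    x = x * torch.tanh(s)
    attn_weights = torch.softmax(torch.randn(1, seq_len, seq_len), dim=2)
    x = torch.matmul(attn_weights, x)  # (1, seq_len, att_dim)
    return x * y
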
class ConvolutionModule(torch.nn.Module):
    """
    ConvolutionModule in the Zipformer2 encoder.
    """

    def __init__(
        self, embed_dim: int, kernel_size: int, right_context: int, device: torch.device,
    ) -> None:
        """
        ConvolutionModule initialization.

        Parameters
        ----------
        embed_dim : int
            The input and output embedding dimension, also the number of channels of the
            convolution modules. The embedding dimension is the same for the input and output
            of this module.
        kernel_size : int
            The kernel size of the depthwise convolution module.
        right_context : int
            The module's look-ahead (future) context, used to update the causal depthwise
            convolution left cache correctly.
        device : torch.device
            The device used to store the layer weights. Should be
            either torch.device("cpu") or torch.device("cuda").
        """

        super().__init__()

        if kernel_size % 2 == 0:
            raise ValueError(
                'ConvolutionModule kernel size should be '
                f'an odd number but got {kernel_size} instead.',
            )

        self.in_proj = torch.nn.Linear(embed_dim, 2 * embed_dim, device=device)
        self.depthwise_conv = ChunkCausalDepthwiseConv1d(
            embed_dim, kernel_size, right_context, device,
        )

        self.activation = SwooshR()
        self.out_proj = torch.nn.Linear(embed_dim, embed_dim, device=device)

    def forward(
        self, x: torch.Tensor, left_cache: torch.Tensor, src_key_padding_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Does a forward pass of the ConvolutionModule module. Returns a processed tensor of the
        same shape as the input and the updated cached convolution tensor of the left context.

        Parameters
        ----------
        x : torch.Tensor[torch.float32]
            The input float tensor of shape (1, seq_len, embed_dim). The module input.
        left_cache : torch.Tensor[torch.float32]
            A cached convolution tensor of the left context
            of shape (1, embed_dim, left_cache_len).
        src_key_padding_mask : torch.Tensor[torch.bool]
            The mask for the source keys of shape (1, seq_len),
            contains True in masked positions that will be ignored.

        Returns
        -------
        tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]]
            A tuple of two float tensors:
            - module output of shape (1, seq_len, embed_dim).
              A tensor with the same shape as the input x.
            - updated cached convolution tensor of the left context
              of shape (1, embed_dim, left_cache_len).
        """

        x = self.in_proj(x)  # (1, seq_len, 2 * embed_dim)

        x, s = x.chunk(2, dim=2)
        x = x * torch.sigmoid(s)  # (1, seq_len, embed_dim)

        x = torch.masked_fill(x, src_key_padding_mask.unsqueeze(2), 0.0)

        # Exchange the temporal dimension and the feature dimension for the
        # depthwise convolution.
        x = x.permute(0, 2, 1)  # (1, embed_dim, seq_len)
        x, left_cache = self.depthwise_conv(x, left_cache)
        x = x.permute(0, 2, 1)  # (1, seq_len, embed_dim)

        x = self.activation(x)
        x = self.out_proj(x)  # (1, seq_len, embed_dim)

        return x, left_cache

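
# A minimal, self-contained sketch (illustration only; not used by the model) of the start of
# ConvolutionModule.forward above: a GLU-style gate (one half of the projection, passed through
# a sigmoid, multiplies the other half), followed by zeroing the padded positions before the
# depthwise convolution. The toy sizes and the example mask below are arbitrary.
def _sketch_conv_module_gate(seq_len: int = 4, embed_dim: int = 6) -> torch.Tensor:
    proj = torch.randn(1, seq_len, 2 * embed_dim)
    x, s = proj.chunk(2, dim=2)
    x = x * torch.sigmoid(s)  # (1, seq_len, embed_dim)
    src_key_padding_mask = torch.zeros(1, seq_len, dtype=torch.bool)
    src_key_padding_mask[:, -1] = True  # e.g. treat the last frame as padding
    return torch.masked_fill(x, src_key_padding_mask.unsqueeze(2), 0.0)
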
def _test_zipformer_main(causal: bool = False):
    batch_size = 5
    seq_len = 20

    c = Zipformer2(
        encoder_dim=(64, 96),
        encoder_unmasked_dim=(48, 64),
        num_heads=(4, 4),
        causal=causal,
        chunk_size=(4,) if causal else (-1,),
        left_context_frames=(64,),
    )

    # Just make sure the forward pass runs.
    f = c(
        torch.randn(seq_len, batch_size, 64),
        torch.full((batch_size,), seq_len, dtype=torch.int64),
    )
    f[0].sum().backward()
    c.eval()
    f = c(
        torch.randn(seq_len, batch_size, 64),
        torch.full((batch_size,), seq_len, dtype=torch.int64),
    )
    f  # to remove flake8 warnings


if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    _test_zipformer_main(False)
    _test_zipformer_main(True)