#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import math
import random
import warnings
from typing import List, Optional, Tuple, Union
import torch
from encoder_interface import EncoderInterface
from scaling import (
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
)
from scaling import (
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
)
from scaling import (
ActivationDropoutAndLinear,
Balancer,
BiasNorm,
ChunkCausalDepthwiseConv1d,
Dropout2,
FloatLike,
ScheduledFloat,
Whiten,
convert_num_channels,
limit_param_value,
penalize_abs_values_gt,
softmax,
)
from torch import Tensor, nn
class Zipformer2(torch.nn.Module):
"""
Zipformer2 encoder.
"""
# pylint: disable=too-many-instance-attributes
def __init__(
self,
input_dim: int,
subsample_output_dim: int,
subsample_layer1_channels: int,
subsample_layer2_channels: int,
subsample_layer3_channels: int,
encoder_dims: list[int],
num_encoder_layers: list[int],
downsampling_factors: list[int],
num_heads: list[int],
feedforward_dims: list[int],
cnn_module_kernels: list[int],
query_head_dim: int,
pos_head_dim: int,
value_head_dim: int,
pos_dim: int,
pos_max_len: int,
output_dim: int,
use_ctc: bool,
left_context_frames: int,
right_context_frames: int,
device: torch.device,
) -> None:
"""
Zipformer2 initialization.
Parameters
----------
input_dim : int
The number of input features.
subsample_output_dim : int
The output dimension of the subsampling module represented by Conv2dSubsampling.
subsample_layer1_channels : int
The number of output channels in the first Conv2d layer of the
Conv2dSubsampling module.
subsample_layer2_channels : int
The number of output channels in the second Conv2d layer of the
Conv2dSubsampling module.
subsample_layer3_channels : int
The number of output channels in the third Conv2d layer of the
Conv2dSubsampling module.
encoder_dims : list[int]
A list of 6 integers, the embedding dimension of the
Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
num_encoder_layers : list[int]
A list of 6 integers, the number of Zipformer2EncoderLayer
modules in each Zipformer2Encoder stack.
downsampling_factors : list[int]
A list of 6 integers, the downsampling factor of each Zipformer2Encoder stack.
Note: this is in addition to the downsampling factor of 2 that is applied in the
Conv2dSubsampling module.
num_heads : list[int]
A list of 6 integers, the number of heads for attention weights and self-attention of
the Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
feedforward_dims : list[int]
A list of 6 integers, the hidden dimension of the feedforward module of
the Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
cnn_module_kernels : list[int]
A list of 6 integers, the kernel size of the convolution module of
the Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
query_head_dim : int
The dimension of the query and key per attention head in attention weights of the
Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
pos_head_dim : int
The dimension of the projected positional encoding per attention head in attention
weights of the Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
value_head_dim : int
The dimension of the value per attention head in self-attention of
the Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
pos_dim: int
The dimension of the relative positional embeddings in each Zipformer2Encoder stack.
pos_max_len : int
The maximum input length supported by the relative positional embeddings in each
Zipformer2Encoder stack without re-computation. Note: if the input to any positional
embedding module exceeds this length, the embeddings are re-computed, which can
significantly degrade inference speed.
output_dim : int
The output dimension after final output projection.
use_ctc : bool
If True, it is assumed that a CTC head is loaded as the output encoder projection.
In this case torch.nn.functional.log_softmax will be applied to the output at the very end.
left_context_frames : int
The left context number of frames after the initial subsampling with
Conv2dSubsampling module.
right_context_frames : int
The right (look-ahead) context number of frames.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
"""
# pylint: disable=too-many-arguments,too-many-locals
super().__init__()
if not (
len(encoder_dims)
== len(num_encoder_layers)
== len(downsampling_factors)
== len(num_heads)
== len(feedforward_dims)
== len(cnn_module_kernels)
== 6
):
raise ValueError(
'It is required that the length of encoder_dims, num_encoder_layers, '
'downsampling_factors, num_heads, feedforward_dims, and cnn_module_kernels is the '
'same and equal to 6, but got the following list lengths:\n'
f'len(num_encoder_layers) == {len(num_encoder_layers)}\n'
f'len(downsampling_factors) == {len(downsampling_factors)}\n'
f'len(encoder_dims) == {len(encoder_dims)}\n'
f'len(num_heads) == {len(num_heads)}\n'
f'len(cnn_module_kernels) == {len(cnn_module_kernels)}\n'
f'len(feedforward_dims) == {len(feedforward_dims)}.',
)
self.encoder_dims = tuple(encoder_dims)
self.downsampling_factors = tuple(downsampling_factors)
self.left_context_frames = left_context_frames
projection_dim = max(encoder_dims)
self.projection_dim = projection_dim
self.ctc = use_ctc
self.subsampling = Conv2dSubsampling(
input_dim,
subsample_output_dim,
subsample_layer1_channels,
subsample_layer2_channels,
subsample_layer3_channels,
right_context_frames,
device,
)
encoders = []
for i, num_layers in enumerate(num_encoder_layers):
encoder_layer = Zipformer2EncoderLayer(
encoder_dims[i],
pos_dim,
num_heads[i],
query_head_dim,
pos_head_dim,
value_head_dim,
feedforward_dims[i],
cnn_module_kernels[i],
left_context_frames // downsampling_factors[i],
right_context_frames // 2 // downsampling_factors[i],
device,
)
encoder = Zipformer2Encoder(
encoder_layer,
num_layers,
encoder_dims[i],
pos_dim,
pos_max_len,
downsampling_factors[i],
device,
)
encoders.append(encoder)
self.encoder_1 = encoders[0]
self.encoder_2 = encoders[1]
self.encoder_3 = encoders[2]
self.encoder_4 = encoders[3]
self.encoder_5 = encoders[4]
self.encoder_6 = encoders[5]
self.downsample_output = SimpleDownsample(2, device)
self.projection_output = torch.nn.Linear(projection_dim, output_dim, device=device)
def forward(
self,
x: torch.Tensor,
# We need to preserve these explicit arguments for a smooth
# TorchScript export followed by the ONNX export.
left_cached_subsample_frames: torch.Tensor,
left_cached_keys_encoder_1: torch.Tensor,
left_cached_nonlin_attentions_encoder_1: torch.Tensor,
left_cached_values_1_encoder_1: torch.Tensor,
left_cached_values_2_encoder_1: torch.Tensor,
left_cached_convolutions_1_encoder_1: torch.Tensor,
left_cached_convolutions_2_encoder_1: torch.Tensor,
left_cached_keys_encoder_2: torch.Tensor,
left_cached_nonlin_attentions_encoder_2: torch.Tensor,
left_cached_values_1_encoder_2: torch.Tensor,
left_cached_values_2_encoder_2: torch.Tensor,
left_cached_convolutions_1_encoder_2: torch.Tensor,
left_cached_convolutions_2_encoder_2: torch.Tensor,
left_cached_keys_encoder_3: torch.Tensor,
left_cached_nonlin_attentions_encoder_3: torch.Tensor,
left_cached_values_1_encoder_3: torch.Tensor,
left_cached_values_2_encoder_3: torch.Tensor,
left_cached_convolutions_1_encoder_3: torch.Tensor,
left_cached_convolutions_2_encoder_3: torch.Tensor,
left_cached_keys_encoder_4: torch.Tensor,
left_cached_nonlin_attentions_encoder_4: torch.Tensor,
left_cached_values_1_encoder_4: torch.Tensor,
left_cached_values_2_encoder_4: torch.Tensor,
left_cached_convolutions_1_encoder_4: torch.Tensor,
left_cached_convolutions_2_encoder_4: torch.Tensor,
left_cached_keys_encoder_5: torch.Tensor,
left_cached_nonlin_attentions_encoder_5: torch.Tensor,
left_cached_values_1_encoder_5: torch.Tensor,
left_cached_values_2_encoder_5: torch.Tensor,
left_cached_convolutions_1_encoder_5: torch.Tensor,
left_cached_convolutions_2_encoder_5: torch.Tensor,
left_cached_keys_encoder_6: torch.Tensor,
left_cached_nonlin_attentions_encoder_6: torch.Tensor,
left_cached_values_1_encoder_6: torch.Tensor,
left_cached_values_2_encoder_6: torch.Tensor,
left_cached_convolutions_1_encoder_6: torch.Tensor,
left_cached_convolutions_2_encoder_6: torch.Tensor,
processed_len: torch.Tensor,
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor, torch.Tensor,
]:
"""
Does a forward pass of the Zipformer2 module, which represents the whole acoustic encoder.
Returns a tuple with the output tensor, updated left cache feature tensor for subsampling
module, 36 left cache tensors for multiple attention and convolution modules within each of
6 Zipformer2Encoder modules, and finally, the updated processed length single-element
tensor with the total number of processed frames after subsampling module.
Parameters
----------
x : torch.Tensor[torch.float32]
The input float feature tensor of shape (1, num_frames, input_dim),
where the input_dim corresponds to the number of features.
left_cached_subsample_frames : torch.Tensor[torch.float32]
The subsampling module left cache tensor of shape (1, 10, input_dim).
left_cached_keys_encoder_1 : torch.Tensor[torch.float32]
The cached attention key tensor of the left context of each
Zipformer2EncoderLayer within the first Zipformer2Encoder.
The tensor is of shape (num_layers_1, 1, left_context_len_1, query_dim_1).
left_cached_nonlin_attentions_encoder_1 : torch.Tensor[torch.float32]
The left context cached attention tensor for the non-linear attention module of each
Zipformer2EncoderLayer within the first Zipformer2Encoder.
The tensor is of shape (num_layers_1, 1, left_context_len_1, head_dim_1).
left_cached_values_1_encoder_1 : torch.Tensor[torch.float32]
The cached left context tensor for the first self-attention module of each
Zipformer2EncoderLayer within the first Zipformer2Encoder.
The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1).
left_cached_values_2_encoder_1 : torch.Tensor[torch.float32]
The cached left context tensor for the second self-attention module of each
Zipformer2EncoderLayer within the first Zipformer2Encoder.
The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1).
left_cached_convolutions_1_encoder_1 : torch.Tensor[torch.float32]
The cached left context tensor for the first convolution module of each
Zipformer2EncoderLayer within the first Zipformer2Encoder.
The tensor is of shape (num_layers_1, 1, embed_dim_1, left_cache_len_1).
left_cached_convolutions_2_encoder_1 : torch.Tensor[torch.float32]
The cached left context tensor for the second convolution module of each
Zipformer2EncoderLayer within the first Zipformer2Encoder.
The tensor is of shape (num_layers_1, 1, embed_dim_1, left_cache_len_1).
.
.
.
left_cached_convolutions_2_encoder_6 : torch.Tensor[torch.float32]
The cached left context tensor for the second convolution module of each
Zipformer2EncoderLayer within the sixth Zipformer2Encoder.
The tensor is of shape (num_layers_6, 1, embed_dim_6, left_cache_len_6).
processed_len : torch.Tensor[torch.int32]
The total processed length after subsampling, single-element integer tensor
of shape (1,).
Returns
-------
tuple[
torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32],
torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.int32],
]
A tuple of 38 float tensors and 1 integer tensor:
- The module output of shape (1, seq_len, output_dim).
- The updated subsampling module left cache tensor of shape (1, 10, input_dim).
- The updated cached attention key tensor of the left context of each
Zipformer2EncoderLayer within the first Zipformer2Encoder.
The tensor is of shape (num_layers_1, 1, left_context_len_1, query_dim_1).
- The updated left context cached attention tensor for the non-linear attention
module of each Zipformer2EncoderLayer within the first Zipformer2Encoder.
The tensor is of shape (num_layers_1, 1, left_context_len_1, head_dim_1).
- The updated cached left context tensor for the first self-attention module of each
Zipformer2EncoderLayer within the first Zipformer2Encoder.
The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1).
- The updated cached left context tensor for the second
self-attention module of each Zipformer2EncoderLayer within the first
Zipformer2Encoder.
The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1).
- The updated cached left context tensor for the first convolution module of each
Zipformer2EncoderLayer within the first Zipformer2Encoder.
The tensor is of shape (num_layers_1, 1, embed_dim_1, left_cache_len_1).
- The updated cached left context tensor for the second convolution module of each
Zipformer2EncoderLayer within the first Zipformer2Encoder.
The tensor is of shape (num_layers_1, 1, embed_dim_1, left_cache_len_1).
.
.
.
- The updated cached left context tensor for the second convolution module of each
Zipformer2EncoderLayer within the sixth Zipformer2Encoder.
The tensor is of shape (num_layers_6, 1, embed_dim_6, left_cache_len_6).
- The updated total processed length tensor after subsampling of shape (1,).
"""
# pylint: disable=too-many-arguments,too-many-locals
x, new_left_cached_subsample_frames = self.subsampling(x, left_cached_subsample_frames)
batch_size, seq_len, _ = x.size()
src_key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=x.device)
# processed_mask is used to mask out the initial states. If left_context_frames == 6,
# the tensor will contain [5, 4, 3, 2, 1, 0], i.e. a reversed
# torch.arange(left_context_frames).
processed_mask = torch.arange(
self.left_context_frames - 1, -1, -1, dtype=torch.int32, device=x.device,
).expand(batch_size, self.left_context_frames)
# (1, left_context_size) i.e. (batch_size, left_context_size)
processed_mask = processed_mask >= processed_len.expand(processed_mask.size())
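# Example: with left_context_frames == 6 and processed_len == 2, the comparison above
# evaluates [5, 4, 3, 2, 1, 0] >= 2, giving [True, True, True, True, False, False],
# i.e. the four oldest (not yet processed) left-context positions are masked out.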
# Update processed lengths
new_processed_len = processed_len + seq_len
# (1, left_context_size + chunk_size)
src_key_padding_mask = torch.cat((processed_mask, src_key_padding_mask), dim=1)
# If the last encoder 'x' has the largest dimension, then the 'output' will be just this
# last 'x' unchanged. Otherwise it will be concatenated from different pieces of 'x',
# taking each output channel dimension from the most recent x that has it present.
output = torch.empty(
batch_size, seq_len, self.projection_dim, dtype=torch.float32, device=x.device,
)
# The number of Zipformer2Encoder stacks is fixed at 6 for any Zipformer2 size,
# including small, medium and large. For the sake of a smoother TorchScript export we
# call each Zipformer2Encoder module explicitly in sequence instead of using
# torch.nn.ModuleList.
# Encoder 1
(
x,
new_left_cached_keys_encoder_1,
new_left_cached_nonlin_attentions_encoder_1,
new_left_cached_values_1_encoder_1,
new_left_cached_values_2_encoder_1,
new_left_cached_convolutions_1_encoder_1,
new_left_cached_convolutions_2_encoder_1,
) = self.encoder_1(
x,
left_cached_keys_encoder_1,
left_cached_nonlin_attentions_encoder_1,
left_cached_values_1_encoder_1,
left_cached_values_2_encoder_1,
left_cached_convolutions_1_encoder_1,
left_cached_convolutions_2_encoder_1,
src_key_padding_mask[:, ::self.downsampling_factors[0]],
)
output[:, :, :x.size(2)] = x
# Encoder 2
pad = torch.zeros(
x.size(0), x.size(1), self.encoder_dims[1] - x.size(2),
dtype=torch.float32,
device=x.device,
)
x = torch.cat((x, pad), dim=2)
(
x,
new_left_cached_keys_encoder_2,
new_left_cached_nonlin_attentions_encoder_2,
new_left_cached_values_1_encoder_2,
new_left_cached_values_2_encoder_2,
new_left_cached_convolutions_1_encoder_2,
new_left_cached_convolutions_2_encoder_2,
) = self.encoder_2(
x,
left_cached_keys_encoder_2,
left_cached_nonlin_attentions_encoder_2,
left_cached_values_1_encoder_2,
left_cached_values_2_encoder_2,
left_cached_convolutions_1_encoder_2,
left_cached_convolutions_2_encoder_2,
src_key_padding_mask[:, ::self.downsampling_factors[1]],
)
output[:, :, :x.size(2)] = x
# Encoder 3
pad = torch.zeros(
x.size(0), x.size(1), self.encoder_dims[2] - x.size(2),
dtype=torch.float32,
device=x.device,
)
x = torch.cat((x, pad), dim=2)
(
x,
new_left_cached_keys_encoder_3,
new_left_cached_nonlin_attentions_encoder_3,
new_left_cached_values_1_encoder_3,
new_left_cached_values_2_encoder_3,
new_left_cached_convolutions_1_encoder_3,
new_left_cached_convolutions_2_encoder_3,
) = self.encoder_3(
x,
left_cached_keys_encoder_3,
left_cached_nonlin_attentions_encoder_3,
left_cached_values_1_encoder_3,
left_cached_values_2_encoder_3,
left_cached_convolutions_1_encoder_3,
left_cached_convolutions_2_encoder_3,
src_key_padding_mask[:, ::self.downsampling_factors[2]],
)
output[:, :, :x.size(2)] = x
# Encoder 4
pad = torch.zeros(
x.size(0), x.size(1), self.encoder_dims[3] - x.size(2),
dtype=torch.float32,
device=x.device,
)
x = torch.cat((x, pad), dim=2)
(
x,
new_left_cached_keys_encoder_4,
new_left_cached_nonlin_attentions_encoder_4,
new_left_cached_values_1_encoder_4,
new_left_cached_values_2_encoder_4,
new_left_cached_convolutions_1_encoder_4,
new_left_cached_convolutions_2_encoder_4,
) = self.encoder_4(
x,
left_cached_keys_encoder_4,
left_cached_nonlin_attentions_encoder_4,
left_cached_values_1_encoder_4,
left_cached_values_2_encoder_4,
left_cached_convolutions_1_encoder_4,
left_cached_convolutions_2_encoder_4,
src_key_padding_mask[:, ::self.downsampling_factors[3]],
)
output[:, :, :x.size(2)] = x
# Encoder 5
x = x[:, :, :self.encoder_dims[4]]
(
x,
new_left_cached_keys_encoder_5,
new_left_cached_nonlin_attentions_encoder_5,
new_left_cached_values_1_encoder_5,
new_left_cached_values_2_encoder_5,
new_left_cached_convolutions_1_encoder_5,
new_left_cached_convolutions_2_encoder_5,
) = self.encoder_5(
x,
left_cached_keys_encoder_5,
left_cached_nonlin_attentions_encoder_5,
left_cached_values_1_encoder_5,
left_cached_values_2_encoder_5,
left_cached_convolutions_1_encoder_5,
left_cached_convolutions_2_encoder_5,
src_key_padding_mask[:, ::self.downsampling_factors[4]],
)
output[:, :, :x.size(2)] = x
# Encoder 6
x = x[:, :, :self.encoder_dims[5]]
(
x,
new_left_cached_keys_encoder_6,
new_left_cached_nonlin_attentions_encoder_6,
new_left_cached_values_1_encoder_6,
new_left_cached_values_2_encoder_6,
new_left_cached_convolutions_1_encoder_6,
new_left_cached_convolutions_2_encoder_6,
) = self.encoder_6(
x,
left_cached_keys_encoder_6,
left_cached_nonlin_attentions_encoder_6,
left_cached_values_1_encoder_6,
left_cached_values_2_encoder_6,
left_cached_convolutions_1_encoder_6,
left_cached_convolutions_2_encoder_6,
src_key_padding_mask[:, ::self.downsampling_factors[5]],
)
output[:, :, :x.size(2)] = x
output = self.downsample_output(output)
output = self.projection_output(output)
if self.ctc:
output = torch.nn.functional.log_softmax(output, dim=2)
return (
output,
# For the reasons mentioned in the previous comments, we preserve the
# explicit listing of each left cache tensor for easier
# TorchScript and ONNX export.
new_left_cached_subsample_frames,
new_left_cached_keys_encoder_1,
new_left_cached_nonlin_attentions_encoder_1,
new_left_cached_values_1_encoder_1,
new_left_cached_values_2_encoder_1,
new_left_cached_convolutions_1_encoder_1,
new_left_cached_convolutions_2_encoder_1,
new_left_cached_keys_encoder_2,
new_left_cached_nonlin_attentions_encoder_2,
new_left_cached_values_1_encoder_2,
new_left_cached_values_2_encoder_2,
new_left_cached_convolutions_1_encoder_2,
new_left_cached_convolutions_2_encoder_2,
new_left_cached_keys_encoder_3,
new_left_cached_nonlin_attentions_encoder_3,
new_left_cached_values_1_encoder_3,
new_left_cached_values_2_encoder_3,
new_left_cached_convolutions_1_encoder_3,
new_left_cached_convolutions_2_encoder_3,
new_left_cached_keys_encoder_4,
new_left_cached_nonlin_attentions_encoder_4,
new_left_cached_values_1_encoder_4,
new_left_cached_values_2_encoder_4,
new_left_cached_convolutions_1_encoder_4,
new_left_cached_convolutions_2_encoder_4,
new_left_cached_keys_encoder_5,
new_left_cached_nonlin_attentions_encoder_5,
new_left_cached_values_1_encoder_5,
new_left_cached_values_2_encoder_5,
new_left_cached_convolutions_1_encoder_5,
new_left_cached_convolutions_2_encoder_5,
new_left_cached_keys_encoder_6,
new_left_cached_nonlin_attentions_encoder_6,
new_left_cached_values_1_encoder_6,
new_left_cached_values_2_encoder_6,
new_left_cached_convolutions_1_encoder_6,
new_left_cached_convolutions_2_encoder_6,
new_processed_len,
)
def get_init_states(
input_dim: int,
num_encoder_layers: list[int],
downsample_left_pad_frames: list[int],
encoder_dims: list[int],
query_dims: list[int],
value_dims: list[int],
head_dims: list[int],
convolution_left_pad_frames: list[int],
device: torch.device,
) -> list[torch.Tensor]:
"""
Get initial states for the Zipformer2 encoder. The method generates a list of torch tensors,
where the first tensor corresponds to a subsampling module left cache. Next, for each
Zipformer2Encoder module we add six cache tensors that are essential for multi-head attention
and convolution modules. Finally, at the end we append a total processed frames tensor,
initialized with zero.
Parameters
----------
input_dim : int
The number of input features.
num_encoder_layers : list[int]
The number of Zipformer2EncoderLayer modules for each Zipformer2Encoder stack.
downsample_left_pad_frames : list[int]
The multi-head attention left context cache frames after downsampling.
encoder_dims : list[int]
The embedding dimension for each Zipformer2Encoder stack.
query_dims : list[int]
The multi-head attention query dimension for each Zipformer2Encoder stack.
value_dims : list[int]
The multi-head attention value dimension for each Zipformer2Encoder stack.
head_dims : list[int]
The non-linear attention head dimension for each Zipformer2Encoder stack.
convolution_left_pad_frames : list[int]
The convolution modules left padding number of frames for each Zipformer2Encoder stack.
device : torch.device
The device used to store cache tensors. Should be
either torch.device("cpu") or torch.device("cuda").
Returns
-------
list[torch.Tensor[torch.float32 | torch.int32]]
A list of left cache tensors.
- A subsampling module left cache tensor of shape (1, 10, input_dim)
- The first Zipformer2Encoder cached attention key tensor of the left context in each
Zipformer2EncoderLayer of the stack.
The tensor is of shape (num_layers_1, 1, left_context_len_1, query_dim_1).
- The first Zipformer2Encoder left context cached attention tensor for the non-linear
attention module in each Zipformer2EncoderLayer of the stack.
The tensor is of shape (num_layers_1, 1, left_context_len_1, head_dim_1).
- The first Zipformer2Encoder cached left context tensor for the first self-attention
module in each Zipformer2EncoderLayer of the stack.
The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1).
- The first Zipformer2Encoder cached left context tensor for the second self-attention
module in each Zipformer2EncoderLayer of the stack.
The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1).
- The first Zipformer2Encoder cached left context tensor for the first convolution module
in each Zipformer2EncoderLayer of the stack.
The tensor is of shape (num_layers_1, 1, encoder_dim_1, conv_left_pad_1).
- The first Zipformer2Encoder cached left context tensor for the second convolution module
in each Zipformer2EncoderLayer of the stack.
The tensor is of shape (num_layers_1, 1, encoder_dim_1, conv_left_pad_1).
.
.
.
- The sixth Zipformer2Encoder cached left context tensor for the second convolution module
in each Zipformer2EncoderLayer of the stack.
The tensor is of shape (num_layers_6, 1, encoder_dim_6, conv_left_pad_6).
- The processed length integer tensor initialized with a single zero element.
The tensor is of shape (1,).
"""
# pylint: disable=too-many-locals
if not (
len(num_encoder_layers)
== len(downsample_left_pad_frames)
== len(encoder_dims)
== len(query_dims)
== len(value_dims)
== len(head_dims)
== len(convolution_left_pad_frames)
):
raise ValueError(
'It is required that all encoder parameter lists have the same '
'length, but got following parameter list lengths:\n'
f'len(num_encoder_layers) == {len(num_encoder_layers)}\n'
f'len(downsample_left_pad_frames) == {len(downsample_left_pad_frames)}\n'
f'len(encoder_dims) == {len(encoder_dims)}\n'
f'len(query_dims) == {len(query_dims)}\n'
f'len(value_dims) == {len(value_dims)}\n'
f'len(head_dims) == {len(head_dims)}\n'
f'len(convolution_left_pad_frames) == {len(convolution_left_pad_frames)}.',
)
states = [subsampling_get_init_states(input_dim, device)]
for i, num_layers in enumerate(num_encoder_layers):
encoder_dim = encoder_dims[i]
query_dim = query_dims[i]
value_dim = value_dims[i]
head_dim = head_dims[i]
left_context_len = downsample_left_pad_frames[i]
left_cache_len = convolution_left_pad_frames[i]
# batch size is 1
states += [
torch.zeros(
num_layers, 1, left_context_len, query_dim, dtype=torch.float32, device=device,
),
torch.zeros(
num_layers, 1, left_context_len, head_dim, dtype=torch.float32, device=device,
),
torch.zeros(
num_layers, 1, left_context_len, value_dim, dtype=torch.float32, device=device,
),
torch.zeros(
num_layers, 1, left_context_len, value_dim, dtype=torch.float32, device=device,
),
torch.zeros(
num_layers, 1, encoder_dim, left_cache_len, dtype=torch.float32, device=device,
),
torch.zeros(
num_layers, 1, encoder_dim, left_cache_len, dtype=torch.float32, device=device,
),
]
states.append(torch.zeros(1, dtype=torch.int32, device=device))
return states
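# A minimal usage sketch of get_init_states (the configuration values below are
# hypothetical and only illustrate the expected list lengths, not real model settings):
#
#     states = get_init_states(
#         input_dim=80,
#         num_encoder_layers=[2, 2, 3, 4, 3, 2],
#         downsample_left_pad_frames=[64, 32, 16, 8, 16, 32],
#         encoder_dims=[192, 256, 384, 512, 384, 256],
#         query_dims=[128] * 6,
#         value_dims=[96] * 6,
#         head_dims=[96] * 6,
#         convolution_left_pad_frames=[15] * 6,
#         device=torch.device("cpu"),
#     )
#     # len(states) == 1 + 6 * 6 + 1 == 38: one subsampling cache, six cache tensors
#     # per Zipformer2Encoder stack, and the processed-length counter.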
class Zipformer2EncoderLayer(torch.nn.Module):
"""
Zipformer2EncoderLayer module, the basic block of Zipformer2Encoder encoder stack.
"""
# pylint: disable=too-many-instance-attributes
def __init__(
self,
embed_dim: int,
pos_dim: int,
num_heads: int,
query_head_dim: int,
pos_head_dim: int,
value_head_dim: int,
feedforward_dim: int,
cnn_module_kernel: int,
left_context_len: int,
right_context_len: int,
device: torch.device,
) -> None:
"""
Zipformer2EncoderLayer initialization.
Parameters
----------
embed_dim : int
The input and output embedding dimension. The number of channels is the same for input
and output of this module.
pos_dim : int
The dimension of the relative positional embedding.
num_heads : int
The number of heads for attention weights and self-attention.
query_head_dim : int
The dimension of the query and key per attention head in attention weights.
pos_head_dim: int
The dimension of the projected positional encoding
per attention head in attention weights.
value_head_dim : int
The dimension of the value per attention head in self-attention.
feedforward_dim : int
The hidden dimension of the feedforward modules.
cnn_module_kernel : int
The kernel size of the convolution modules.
left_context_len : int
The module left context number of subsampled frames.
right_context_len : int
The module right context number of subsampled frames.
Used to update attention and convolution left caches.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
"""
# pylint: disable=too-many-arguments
super().__init__()
self.left_context_len = left_context_len
# self.bypass implements the whole layer skipping.
self.bypass = BypassModule(embed_dim, device)
# bypass_mid is bypass used in the middle of the layer.
self.bypass_mid = BypassModule(embed_dim, device)
self.self_attn_weights = RelPositionMultiheadAttentionWeights(
embed_dim, pos_dim, num_heads, query_head_dim, pos_head_dim, right_context_len, device,
)
self.self_attn1 = SelfAttention(
embed_dim, num_heads, value_head_dim, right_context_len, device,
)
self.self_attn2 = SelfAttention(
embed_dim, num_heads, value_head_dim, right_context_len, device,
)
self.nonlin_attention = NonlinAttention(
embed_dim, 3 * embed_dim // 4, right_context_len, device,
)
self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, device)
self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, device)
self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4, device)
self.conv_module1 = ConvolutionModule(
embed_dim, cnn_module_kernel, right_context_len, device,
)
self.conv_module2 = ConvolutionModule(
embed_dim, cnn_module_kernel, right_context_len, device,
)
self.norm = BiasNorm(embed_dim, device)
def forward(
self,
src: torch.Tensor,
pos_emb: torch.Tensor,
left_cached_key: torch.Tensor,
left_cached_nonlin_attn: torch.Tensor,
left_cached_val_1: torch.Tensor,
left_cached_val_2: torch.Tensor,
left_cached_conv_1: torch.Tensor,
left_cached_conv_2: torch.Tensor,
src_key_padding_mask: torch.Tensor,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""
Does a forward pass of the Zipformer2EncoderLayer module. Returns an output tensor with the
same shape as the input, and updated left caches for multiple attention and convolution
modules.
Parameters
----------
src : torch.Tensor[torch.float32]
The input float tensor of shape (1, seq_len, embed_dim). The module input.
pos_emb : torch.Tensor[torch.float32]
A positional embedding tensor
of shape (1, left_context_len + 2 * seq_len - 1, pos_dim).
left_cached_key : torch.Tensor[torch.float32]
A cached attention key tensor of the left context
of shape (1, left_context_len, query_dim).
left_cached_nonlin_attn : torch.Tensor[torch.float32]
A left context cached attention tensor for the non-linear attention module
of shape (1, left_context_len, head_dim).
left_cached_val_1 : torch.Tensor[torch.float32]
A cached left context tensor for the first self-attention module
of shape (1, left_context_len, value_dim).
left_cached_val_2 : torch.Tensor[torch.float32]
A cached left context for the second self-attention module
of shape (1, left_context_len, value_dim).
left_cached_conv_1 : torch.Tensor[torch.float32]
A cached left context tensor for the first convolution module
of shape (1, embed_dim, left_cache_len).
left_cached_conv_2 : torch.Tensor[torch.float32]
A cached left context tensor for the second convolution module
of shape (1, embed_dim, left_cache_len).
src_key_padding_mask : torch.Tensor[torch.bool]
A boolean tensor of shape (1, seq_len_2). Positions that are True in this mask will be
ignored as sources in the attention weighting and convolution modules.
Returns
-------
tuple[
torch.Tensor[torch.float32],
torch.Tensor[torch.float32],
torch.Tensor[torch.float32],
torch.Tensor[torch.float32],
torch.Tensor[torch.float32],
torch.Tensor[torch.float32],
torch.Tensor[torch.float32],
]
A tuple of seven float tensors:
- The module output of shape (1, seq_len, embed_dim).
A tensor with the same shape as input.
- The updated left context cached attention key tensor
of shape (1, left_context_len, query_dim).
- The updated left context cached attention tensor for the non-linear attention module
of shape (1, left_context_len, head_dim).
- The updated cached left context for the first self-attention module
of shape (1, left_context_len, value_dim).
- The updated cached left context for the second self-attention module
of shape (1, left_context_len, value_dim).
- The updated cached left context for the first convolution module
of shape (1, embed_dim, left_cache_len).
- The updated cached left context for the second convolution module
of shape (1, embed_dim, left_cache_len).
"""
src_orig = src
# attn_weights: (1, num_heads, seq_len, seq_len_2)
attn_weights, left_cached_key = self.self_attn_weights(
src, pos_emb, left_cached_key, src_key_padding_mask,
)
src = src + self.feed_forward1(src)
na, left_cached_nonlin_attn = self.nonlin_attention(
src, attn_weights[:, 0], left_cached_nonlin_attn,
)
src = src + na
self_attn, left_cached_val_1 = self.self_attn1(src, attn_weights, left_cached_val_1)
src = src + self_attn
src_conv, left_cached_conv_1 = self.conv_module1(
src, left_cached_conv_1, src_key_padding_mask[:, self.left_context_len:],
)
src = src + src_conv
src = src + self.feed_forward2(src)
# bypass in the middle of the layer.
src = self.bypass_mid(src_orig, src)
self_attn, left_cached_val_2 = self.self_attn2(src, attn_weights, left_cached_val_2)
src = src + self_attn
src_conv, left_cached_conv_2 = self.conv_module2(
src, left_cached_conv_2, src_key_padding_mask[:, self.left_context_len:],
)
src = src + src_conv
src = src + self.feed_forward3(src)
src = self.norm(src)
src = self.bypass(src_orig, src)
return (
src,
left_cached_key,
left_cached_nonlin_attn,
left_cached_val_1,
left_cached_val_2,
left_cached_conv_1,
left_cached_conv_2,
)
class Zipformer2Encoder(torch.nn.Module):
"""
Zipformer2Encoder is a stack of Zipformer2EncoderLayer modules.
"""
def __init__(
self,
encoder_layer: torch.nn.Module,
num_layers: int,
embed_dim: int,
pos_dim: int,
pos_max_len: int,
downsample: int,
device: torch.device,
) -> None:
"""
Zipformer2Encoder initialization.
Parameters
----------
encoder_layer : torch.nn.Module
An instance of the Zipformer2EncoderLayer class.
num_layers : int
The number of encoder Zipformer2EncoderLayer modules in the stack.
embed_dim : int
The input and output embedding dimension. The embedding dimension is the same for
input and output of this module.
pos_dim : int
The dimension of the relative positional embedding.
pos_max_len : int
The maximum input length supported by the relative positional embeddings
without re-computation.
downsample : int
The downsampling factor of the module; the input will be downsampled at the beginning
and upsampled back at the end.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
"""
super().__init__()
self.num_layers = num_layers
self.downsample = SimpleDownsample(downsample, device)
self.encoder_pos = CompactRelPositionalEncoding(
pos_dim, pos_max_len, encoder_layer.left_context_len, device,
)
self.layers = torch.nn.ModuleList(
[copy.deepcopy(encoder_layer) for _ in range(num_layers)],
)
self.upsample = SimpleUpsample(downsample)
self.out_combiner = BypassModule(embed_dim, device)
def forward(
self,
src: torch.Tensor,
left_cached_keys: torch.Tensor,
left_cached_nonlin_attentions: torch.Tensor,
left_cached_values_1: torch.Tensor,
left_cached_values_2: torch.Tensor,
left_cached_convolutions_1: torch.Tensor,
left_cached_convolutions_2: torch.Tensor,
src_key_padding_mask: torch.Tensor,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""
Does a forward pass of the Zipformer2Encoder module. Returns an output tensor with the same
shape as the input, and updated left caches for multiple attention and convolution modules.
Parameters
----------
src : torch.Tensor[torch.float32]
The input float tensor of shape (1, seq_len, embed_dim). The module input.
left_cached_keys : torch.Tensor[torch.float32]
A cached attention key tensor of the left context for each Zipformer2EncoderLayer.
A tensor is of shape (num_layers, 1, left_context_len, query_dim).
left_cached_nonlin_attentions : torch.Tensor[torch.float32]
A left context cached attention tensor for the non-linear attention module of each
Zipformer2EncoderLayer. A tensor is
of shape (num_layers, 1, left_context_len, head_dim).
left_cached_values_1 : torch.Tensor[torch.float32]
A cached left context tensor for the first self-attention module of each
Zipformer2EncoderLayer. A tensor is
of shape (num_layers, 1, left_context_len, value_dim).
left_cached_values_2 : torch.Tensor[torch.float32]
A cached left context tensor for the second self-attention module of each
Zipformer2EncoderLayer. A tensor is
of shape (num_layers, 1, left_context_len, value_dim).
left_cached_convolutions_1 : torch.Tensor[torch.float32]
A cached left context tensor for the first convolution module of each
Zipformer2EncoderLayer. A tensor is
of shape (num_layers, 1, embed_dim, left_cache_len).
left_cached_convolutions_2 : torch.Tensor[torch.float32]
A cached left context tensor for the second convolution module of each
Zipformer2EncoderLayer. A tensor is
of shape (num_layers, 1, embed_dim, left_cache_len).
src_key_padding_mask : torch.Tensor[torch.bool]
A boolean tensor of shape (1, seq_len_2). Positions that are True in this mask will be
ignored as sources in the attention weighting and convolution modules.
Returns
-------
tuple[
torch.Tensor[torch.float32],
torch.Tensor[torch.float32],
torch.Tensor[torch.float32],
torch.Tensor[torch.float32],
torch.Tensor[torch.float32],
torch.Tensor[torch.float32],
torch.Tensor[torch.float32],
]
A tuple of seven float tensors:
- The module output of shape (1, seq_len, embed_dim).
A tensor with the same shape as input.
- The updated cached attention key tensor of the left context for each
Zipformer2EncoderLayer. A tensor is
of shape (num_layers, 1, left_context_len, query_dim).
- The updated left context cached attention tensor for the non-linear attention module
of each Zipformer2EncoderLayer. A tensor is
of shape (num_layers, 1, left_context_len, head_dim).
- The updated cached left context tensor for the first self-attention module of each
Zipformer2EncoderLayer. A tensor is
of shape (num_layers, 1, left_context_len, value_dim).
- The updated cached left context tensor for the second self-attention module of each
Zipformer2EncoderLayer. A tensor is
of shape (num_layers, 1, left_context_len, value_dim).
- The updated cached left context tensor for the first convolution module of each
Zipformer2EncoderLayer. A tensor is
of shape (num_layers, 1, embed_dim, left_cache_len).
- The updated cached left context tensor for the second convolution module of each
Zipformer2EncoderLayer. A tensor is
of shape (num_layers, 1, embed_dim, left_cache_len).
"""
# pylint: disable=too-many-locals
src_orig = src
src = self.downsample(src)
pos_emb = self.encoder_pos(src)
new_left_cached_keys = torch.empty(
left_cached_keys.shape, dtype=torch.float32, device=left_cached_keys.device,
)
new_left_cached_nonlin_attentions = torch.empty(
left_cached_nonlin_attentions.shape,
dtype=torch.float32,
device=left_cached_nonlin_attentions.device,
)
new_left_cached_values_1 = torch.empty(
left_cached_values_1.shape, dtype=torch.float32, device=left_cached_values_1.device,
)
new_left_cached_values_2 = torch.empty(
left_cached_values_2.shape, dtype=torch.float32, device=left_cached_values_2.device,
)
new_left_cached_convolutions_1 = torch.empty(
left_cached_convolutions_1.shape,
dtype=torch.float32,
device=left_cached_convolutions_1.device,
)
new_left_cached_convolutions_2 = torch.empty(
left_cached_convolutions_2.shape,
dtype=torch.float32,
device=left_cached_convolutions_2.device,
)
for i, mod in enumerate(self.layers):
(
src,
new_left_cached_keys[i],
new_left_cached_nonlin_attentions[i],
new_left_cached_values_1[i],
new_left_cached_values_2[i],
new_left_cached_convolutions_1[i],
new_left_cached_convolutions_2[i],
) = mod(
src,
pos_emb,
left_cached_keys[i],
left_cached_nonlin_attentions[i],
left_cached_values_1[i],
left_cached_values_2[i],
left_cached_convolutions_1[i],
left_cached_convolutions_2[i],
src_key_padding_mask,
)
src = self.upsample(src)
# Remove any extra frames that are not a multiple of downsample_factor
src = src[:, : src_orig.size(1)]
src = self.out_combiner(src_orig, src)
return (
src,
new_left_cached_keys,
new_left_cached_nonlin_attentions,
new_left_cached_values_1,
new_left_cached_values_2,
new_left_cached_convolutions_1,
new_left_cached_convolutions_2,
)
class BypassModule(torch.nn.Module):
"""
A bypass module that implements a learnable bypass scale for each input channel.
"""
def __init__(self, num_channels: int, device: torch.device) -> None:
"""
BypassModule initialization.
Parameters
----------
num_channels : int
The number of input channels, corresponds to the number of learnable bypass scales.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
"""
super().__init__()
self.bypass_scale = torch.nn.Parameter(
torch.ones(num_channels, dtype=torch.float32, device=device),
)
def forward(self, x_early: torch.Tensor, x_later: torch.Tensor) -> torch.Tensor:
"""
Does a forward pass of the BypassModule module.
Parameters
----------
x_early : torch.Tensor[torch.float32]
The input float tensor of shape (1, seq_len, num_channels).
The module input that will be propagated with (1 - self.bypass_scale) weight.
x_later : torch.Tensor[torch.float32]
An input float tensor of shape (1, seq_len, num_channels).
The module input that will be propagated with self.bypass_scale weight.
Returns
-------
torch.Tensor[torch.float32]
A float tensor of shape (1, seq_len, num_channels). The shape is the same for x_early
and x_later. The output of the module is x_early bypassed and added to x_later.
"""
# It's just a slightly more efficient implementation of
# (1.0 - self.bypass_scale) * x_early + self.bypass_scale * x_later
return x_early + (x_later - x_early) * self.bypass_scale
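# Illustrative sketch (not part of the original module): numerically checks that the fused
# form used in BypassModule.forward matches the explicit convex combination
# (1 - scale) * x_early + scale * x_later.
def _bypass_equivalence_sketch() -> None:
    x_early = torch.randn(1, 4, 8)
    x_later = torch.randn(1, 4, 8)
    scale = torch.rand(8)
    fused = x_early + (x_later - x_early) * scale
    explicit = (1.0 - scale) * x_early + scale * x_later
    assert torch.allclose(fused, explicit, atol=1.0e-6)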
class SimpleDownsample(torch.nn.Module):
"""
A downsample layer, does downsampling by weighted sum aggregation.
"""
def __init__(self, downsample: int, device: torch.device) -> None:
"""
SimpleDownsample initialization.
Parameters
----------
downsample : int
The module downsampling factor.
device : torch.device
The device used to store the layer weights.
Either torch.device("cpu") or torch.device("cuda").
"""
super().__init__()
self.weights = torch.nn.Parameter(
torch.zeros(downsample, 1, dtype=torch.float32, device=device),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Does a forward pass of the SimpleDownsample module.
Parameters
----------
x : torch.Tensor[torch.float32]
The input float tensor of shape (1, seq_len, num_channels).
The module input that will be downsampled.
Returns
-------
torch.Tensor[torch.float32]
A float tensor of shape
(1, (seq_len + downsample - 1) // downsample, num_channels).
The downsampled output of the module.
"""
downsample = self.weights.size(0)
if downsample == 1:
return x
batch_size, seq_len, in_channels = x.size() # batch_size is 1
downsampled_seq_len = (seq_len + downsample - 1) // downsample
# Pad to an exact multiple of downsample. Right-pad x, repeating the last element.
pad_frames = downsampled_seq_len * downsample - seq_len
if pad_frames > 0:
pad = x[:, seq_len - 1:, :].expand(batch_size, pad_frames, in_channels)
x = torch.cat((x, pad), dim=1)
# (1, seq_len, in_channels) -> (1, seq_len // downsample, downsample, in_channels)
x = x.reshape(batch_size, downsampled_seq_len, downsample, in_channels)
x = torch.sum(x * self.weights, dim=2)
return x
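# Illustrative sketch (not part of the original module): shows the shape contract of
# SimpleDownsample, (1, seq_len, C) -> (1, ceil(seq_len / downsample), C), including the
# right padding that repeats the last frame when seq_len is not a multiple of downsample.
def _simple_downsample_sketch() -> None:
    downsample = SimpleDownsample(2, torch.device("cpu"))
    x = torch.randn(1, 7, 16)  # 7 frames are right-padded to 8 before the weighted sum.
    y = downsample(x)
    assert y.shape == (1, 4, 16)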
class SimpleUpsample(torch.nn.Module):
"""
An upsample layer, does upsampling by repeating the input frames.
"""
def __init__(self, upsample: int) -> None:
"""
SimpleUpsample initialization.
Parameters
----------
upsample : int
The module upsampling factor.
"""
super().__init__()
self.upsample = upsample
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Does a forward pass of the SimpleUpsample module.
Parameters
----------
x : torch.Tensor[torch.float32]
The input float tensor of shape (1, seq_len, num_channels).
The module input that will be upsampled.
Returns
-------
torch.Tensor[torch.float32]
A float tensor of shape (1, seq_len * upsample, num_channels).
The upsampled output of the module.
"""
if self.upsample == 1:
return x
x = torch.repeat_interleave(x, self.upsample, dim=1)
return x
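# Illustrative sketch (not part of the original module): SimpleUpsample repeats every frame
# `upsample` times, so upsample(downsample(x)) can be slightly longer than x; this is why
# Zipformer2Encoder.forward trims the result back to the original length with
# src = src[:, : src_orig.size(1)].
def _simple_upsample_sketch() -> None:
    upsample = SimpleUpsample(3)
    x = torch.randn(1, 4, 16)
    y = upsample(x)
    assert y.shape == (1, 12, 16)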
class CompactRelPositionalEncoding(torch.nn.Module):
"""
Relative positional encoding module. This version is "compact" meaning it is able to encode the
important information about the relative positions in a relatively small number of dimensions.
The goal is to make it so that small differences between large relative offsets
(e.g. 1000 vs. 1001) make very little difference to the embedding. Such differences were
potentially important when encoding absolute position, but not important when encoding relative
position because there is now no need to compare two large offsets with each other.
This implementation works by projecting the interval (-infinity, infinity) onto a finite
interval using the torch.atan() function before doing the Fourier transform of that fixed
interval. On its own, torch.atan() would compress the "long tails" too much, making it hard
to distinguish between different magnitudes of large offsets. To mitigate this, a logarithmic
function is used to compress large offsets to a smaller range before applying torch.atan().
The scalings are chosen so that the embedding can clearly distinguish individual offsets
as long as they are quite close to the origin, e.g. abs(offset) <= sqrt(embedding_dim).
"""
def __init__(
self, embed_dim: int, max_length: int, left_context_len: int, device: torch.device,
) -> None:
"""
CompactRelPositionalEncoding initialization.
Parameters
----------
embed_dim : int
The positional embedding dimension.
max_length : int
The maximum length of the input that this module will be able to handle after
initialization without expanding the positional embeddings. For a longer input the
positional embeddings will be re-computed to fit the bigger length.
left_context_len : int
Length of cached left context.
device : torch.device
The device used to store the layer positional embeddings.
Should be either torch.device("cpu") or torch.device("cuda").
"""
super().__init__()
if embed_dim % 2 != 0:
raise ValueError(
'Embedding dimension for CompactRelPositionalEncoding '
f'should be an even number, but got {embed_dim}.',
)
self.embed_dim = embed_dim
self.left_context_len = left_context_len
self.pos_emb = self.create_pos_emb(max_length, device)
def create_pos_emb(self, max_length: int, device: torch.device) -> torch.Tensor:
"""
Creates relative positional embeddings based on the maximum length.
This method is used to create positional embeddings with a
sufficiently long temporal axis during module initialization.
We want it to be big enough to avoid getting an input x that is longer
than self.pos_emb during inference. On the other hand, we want
to initialize it with the smallest maximum length possible to consume
less memory.
Parameters
----------
max_length : int
The maximum length of the input that can be handled by this layer. Increasing this
allows processing longer inputs (along the temporal dimension), but also
increases memory consumption.
device : torch.device
The device used to store the positional embeddings.
Should be either torch.device("cpu") or torch.device("cuda").
Returns
-------
torch.Tensor[torch.float32]
A float tensor of shape (2 * max_length - 1, embed_dim).
Relative positional embeddings.
"""
# if max_length == 4, the x would contain [-3, -2, -1, 0, 1, 2, 3]
x = torch.arange(-max_length + 1, max_length, dtype=torch.float32, device=device)
# The compression length is an arbitrary heuristic: if it is larger, we have more resolution
# for small time offsets but less resolution for large time offsets.
compression_length = self.embed_dim**0.5
# The next line compresses x. Like the uncompressed x, the result goes from -infinity to
# infinity as the sequence length goes from -infinity to infinity, but it does so more
# slowly than the sequence length for large absolute values.
# The formula is chosen so that d(x_compressed) / dx is equal to 1 around x == 0,
# which is important.
x = compression_length * torch.sign(x) * torch.log(torch.abs(x) / compression_length + 1.0)
# torch.atan() results are between -pi/2 and pi/2.
x = torch.atan(2.0 * torch.pi * x / self.embed_dim)
freqs = torch.arange(1, self.embed_dim // 2 + 1, dtype=torch.float32, device=device)
x = x.unsqueeze(1) * freqs
pos_emb = torch.zeros(x.size(0), self.embed_dim, dtype=torch.float32, device=device)
pos_emb[:, 0::2] = torch.cos(x)
pos_emb[:, 1::2] = torch.sin(x)
pos_emb[:, self.embed_dim - 1] = 1.0 # for bias.
return pos_emb
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Does a forward pass of the CompactRelPositionalEncoding module.
Returns a relative positional embeddings based on the input x temporal dimension.
Parameters
----------
x : torch.Tensor[torch.float32]
An input float tensor of shape (1, seq_len, embed_dim).
The module input. Its shape will be used to construct the relative positional embeddings.
Returns
-------
torch.Tensor[torch.float32]
A float tensor of shape (1, self.left_context_len + 2 * seq_len - 1, embed_dim).
Relative positional embeddings.
"""
if self.pos_emb.size(0) < 2 * (x.size(1) + self.left_context_len) - 1:
self.pos_emb = self.create_pos_emb(x.size(1) + self.left_context_len, x.device)
# Length of negative side: x.size(1) + self.left_context_len.
# Length of positive side: x.size(1).
pos_emb = self.pos_emb[
self.pos_emb.size(0) // 2 - x.size(1) - self.left_context_len + 1:
self.pos_emb.size(0) // 2 + x.size(1)
].unsqueeze(0).repeat(x.size(0), 1, 1)
# (1, left_context_len + 2 * seq_len - 1, embed_dim),
# i. e. (batch_size, pos_len, embed_dim).
return pos_emb
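# Illustrative sketch (not part of the original module): the returned embedding covers
# left_context_len + seq_len - 1 negative offsets, the zero offset and seq_len - 1 positive
# offsets, i.e. left_context_len + 2 * seq_len - 1 positions in total.
def _compact_rel_pos_sketch() -> None:
    pos_encoding = CompactRelPositionalEncoding(64, 128, 32, torch.device("cpu"))
    x = torch.randn(1, 16, 64)
    pos_emb = pos_encoding(x)
    assert pos_emb.shape == (1, 32 + 2 * 16 - 1, 64)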
class RelPositionMultiheadAttentionWeights(torch.nn.Module):
"""
Module that computes multi-head attention weights with relative position encoding.
Various other modules consume the resulting attention weights: see, for example,
the SelfAttention module which allows you to compute conventional self-attention.
This is quite heavily modified from:
"Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
"""
def __init__(
self,
embed_dim: int,
pos_dim: int,
num_heads: int,
query_head_dim: int,
pos_head_dim: int,
right_context: int,
device: torch.device,
) -> None:
"""
RelPositionMultiheadAttentionWeights initialization.
Parameters
----------
embed_dim : int
The embedding dimension. The number of channels at the input to this module.
pos_dim : int
A dimension of the positional embeddings.
num_heads : int
The number of attention heads to compute weights.
query_head_dim : int
The dimension of the query and key per head.
pos_head_dim : int
The dimension of the projected positional encoding per head.
right_context : int
The module's look-ahead (future) context, used to correctly update the left
cached attention key.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
"""
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.query_head_dim = query_head_dim
self.pos_head_dim = pos_head_dim
self.right_context = right_context
in_proj_dim = (2 * query_head_dim + pos_head_dim) * num_heads
self.in_proj = torch.nn.Linear(embed_dim, in_proj_dim, device=device)
# Linear transformation for positional encoding.
self.linear_pos = torch.nn.Linear(
pos_dim, num_heads * pos_head_dim, bias=False, device=device,
)
def forward(
self, x: torch.Tensor,
pos_emb: torch.Tensor,
left_cached_key: torch.Tensor,
key_padding_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Does a forward pass of the RelPositionMultiheadAttentionWeights module.
Returns attention weights and updated cached attention key tensor of the left context.
Parameters
----------
x : torch.Tensor[torch.float32]
The input float tensor of shape (1, seq_len, embed_dim). The module input.
pos_emb : torch.Tensor[torch.float32]
A positional embedding tensor
of shape (1, left_context_len + 2 * seq_len - 1, pos_dim).
left_cached_key : torch.Tensor[torch.float32]
A cached attention key tensor of the left context
of shape (1, left_context_len, key_dim).
key_padding_mask : torch.Tensor[torch.bool]
A boolean tensor of shape (1, seq_len_2). Positions that are True in this mask will be
ignored as sources in the attention weighting.
Returns
-------
tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]]
A tuple of two float tensors:
- attention weights, of shape (1, num_heads, seq_len, seq_len_2)
interpreted as (1, num_heads, tgt_seq_len, src_seq_len).
- updated cached attention key tensor of the left context
of shape (1, left_context_len, key_dim).
"""
# pylint: disable=too-many-locals
batch_size = x.size(0) # batch size is 1
seq_len = x.size(1)
x = self.in_proj(x)
query_dim = self.query_head_dim * self.num_heads
# Self-attention.
q = x[:, :, :query_dim]
k = x[:, :, query_dim: 2 * query_dim]
# p is the position-encoding query.
p = x[:, :, 2 * query_dim:]
# Pad key with cached left context.
k = torch.cat((left_cached_key, k), dim=1)
# Update cached left contexts
seq_len_2 = k.size(1) # left_context_len + seq_len
left_cached_key = k[
:,
seq_len_2 - self.right_context - left_cached_key.size(1):
seq_len_2 - self.right_context,
]
q = q.reshape(batch_size, seq_len, self.num_heads, self.query_head_dim)
p = p.reshape(batch_size, seq_len, self.num_heads, self.pos_head_dim)
k = k.reshape(batch_size, seq_len_2, self.num_heads, self.query_head_dim)
# seq_len refers to target, seq_len_2 refers to source.
q = q.permute(0, 2, 1, 3) # (1, num_heads, seq_len, query_head_dim)
p = p.permute(0, 2, 1, 3) # (1, num_heads, seq_len, pos_head_dim)
k = k.permute(0, 2, 3, 1) # (1, num_heads, key_head_dim, seq_len_2)
attn_scores = torch.matmul(q, k) # (1, num_heads, seq_len, seq_len_2)
pos_len = pos_emb.size(1) # left_context_len + 2 * seq_len - 1
# (1, pos_len, num_heads * pos_head_dim)
pos_emb = self.linear_pos(pos_emb)
pos_emb = pos_emb.reshape(
batch_size, pos_len, self.num_heads, self.pos_head_dim,
).permute(0, 2, 3, 1) # (1, num_heads, pos_head_dim, pos_len)
# (1, num_heads, seq_len, pos_head_dim) x (1, num_heads, pos_head_dim, pos_len) ->
# -> (1, num_heads, seq_len, pos_len) where pos_len represents relative position.
pos_scores = torch.matmul(p, pos_emb)
# Now we need to perform the relative shift of the pos_scores, to do that we need to add
# a column of zeros to the left side of the last dimension and perform the relative shift.
pos_scores_pad = torch.zeros(
pos_scores.size(0), pos_scores.size(1), pos_scores.size(2), 1,
dtype=torch.float32,
device=pos_scores.device,
)
# (1, num_heads, seq_len, pos_len + 1)
pos_scores = torch.cat((pos_scores_pad, pos_scores), dim=3)
pos_scores = pos_scores.reshape(
batch_size, self.num_heads, pos_len + 1, seq_len,
) # (1, num_heads, pos_len + 1, seq_len)
# Now drop the extra row that had been added over padding and reshape.
pos_scores = pos_scores[:, :, 1:].reshape(
batch_size, self.num_heads, seq_len, pos_len,
) # (1, num_heads, seq_len, pos_len)
# (1, num_heads, seq_len, seq_len_2)
attn_scores = attn_scores + pos_scores[:, :, :, : attn_scores.size(3)]
# (1, seq_len_2) -> (1, 1, 1, seq_len_2) to make it broadcastable to attn_scores shape.
key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_scores = torch.masked_fill(attn_scores, key_padding_mask, -1000.0)
attn_weights = torch.softmax(attn_scores, dim=3)
return attn_weights, left_cached_key
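# Illustrative sketch (not part of the original module): the pad-and-reshape trick above
# implements the Transformer-XL style relative shift, i.e. after the shift
# shifted[b, h, i, j] == scores[b, h, i, j + seq_len - 1 - i] (zero when out of range).
def _relative_shift_sketch() -> None:
    batch_size, num_heads, seq_len, pos_len = 1, 1, 2, 3
    scores = torch.arange(seq_len * pos_len, dtype=torch.float32).reshape(
        batch_size, num_heads, seq_len, pos_len,
    )
    pad = torch.zeros(batch_size, num_heads, seq_len, 1, dtype=torch.float32)
    shifted = torch.cat((pad, scores), dim=3)
    shifted = shifted.reshape(batch_size, num_heads, pos_len + 1, seq_len)
    shifted = shifted[:, :, 1:].reshape(batch_size, num_heads, seq_len, pos_len)
    # scores rows are [0, 1, 2] and [3, 4, 5]; after the shift they become
    # [1, 2, 0] and [3, 4, 5].
    assert shifted[0, 0].tolist() == [[1.0, 2.0, 0.0], [3.0, 4.0, 5.0]]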
class SelfAttention(torch.nn.Module):
"""
The simplest possible attention module. This one works with pre-computed attention weights,
e.g. as computed by RelPositionMultiheadAttentionWeights.
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
value_head_dim: int,
right_context: int,
device: torch.device,
) -> None:
"""
SelfAttention initialization.
Parameters
----------
embed_dim : int
The input and output embedding dimension. The number of channels is the same for input
and output of this module.
num_heads : int
The number of attention heads.
value_head_dim : int
The dimension of the value per head.
right_context : int
            The module's look-ahead (future) context, used to correctly update the
            cached attention value tensor of the left context.
device : torch.device
            The device used to store the layer weights.
            Either torch.device("cpu") or torch.device("cuda").
"""
super().__init__()
self.in_proj = torch.nn.Linear(embed_dim, num_heads * value_head_dim, device=device)
self.out_proj = torch.nn.Linear(num_heads * value_head_dim, embed_dim, device=device)
self.right_context = right_context
def forward(
self, x: torch.Tensor, attn_weights: torch.Tensor, left_cached_val: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Does a forward pass of the SelfAttention module. Returns attention weighted input tensor
and updated cached attention value tensor of the left context.
Parameters
----------
x : torch.Tensor[torch.float32]
The input float tensor of shape (1, seq_len, embed_dim). The module input.
attn_weights : torch.Tensor[torch.float32]
The tensor of shape (1, num_heads, seq_len, seq_len_2), with (seq_len, seq_len_2)
            being interpreted as (tgt_seq_len, src_seq_len). It is expected that
            attn_weights.sum(dim=3) == 1.0.
left_cached_val : torch.Tensor[torch.float32]
The cached attention value tensor of the left context
of shape (1, left_context_len, value_dim).
Returns
-------
tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]]
A tuple of two float tensors:
- attention weighted output of shape (1, seq_len, embed_dim).
A tensor with the same shape as input x.
- updated cached attention value tensor of the left context
of shape (1, left_context_len, value_dim).
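
        Examples
        --------
        A minimal shape-check sketch; the sizes below are illustrative and not taken from
        any real configuration (left_context_len == 4, value_dim == num_heads * value_head_dim):

        >>> attn = SelfAttention(
        ...     embed_dim=192, num_heads=4, value_head_dim=12,
        ...     right_context=0, device=torch.device("cpu"),
        ... )
        >>> x = torch.randn(1, 8, 192)
        >>> weights = torch.softmax(torch.randn(1, 4, 8, 12), dim=3)
        >>> cache = torch.zeros(1, 4, 48)
        >>> out, new_cache = attn(x, weights, cache)
        >>> out.shape, new_cache.shape
        (torch.Size([1, 8, 192]), torch.Size([1, 4, 48]))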
"""
batch_size = x.size(0) # batch size is 1
num_heads = attn_weights.size(1)
x = self.in_proj(x) # (1, seq_len, num_heads * value_head_dim)
x = torch.cat((left_cached_val, x), dim=1)
# Update cached left contexts
left_cached_val = x[
:,
x.size(1) - self.right_context - left_cached_val.size(1):
x.size(1) - self.right_context,
]
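        # (1, seq_len_2, num_heads * value_head_dim) -> (1, num_heads, seq_len_2, value_head_dim)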
x = x.reshape(batch_size, x.size(1), num_heads, x.size(2) // num_heads).permute(0, 2, 1, 3)
# (1, num_heads, seq_len, seq_len_2) x (1, num_heads, seq_len_2, value_head_dim) ->
# -> (1, num_heads, seq_len, value_head_dim)
x = torch.matmul(attn_weights, x)
# (1, num_heads, seq_len, value_head_dim) -> (1, seq_len, num_heads, value_head_dim)
x = x.permute(0, 2, 1, 3)
x = x.reshape(batch_size, x.size(1), num_heads * x.size(3))
# returned value is of shape (1, seq_len, embed_dim), like the input.
x = self.out_proj(x)
return x, left_cached_val
class FeedforwardModule(torch.nn.Module):
"""
Feedforward module in Zipformer2 encoder.
"""
def __init__(self, embed_dim: int, feedforward_dim: int, device: torch.device) -> None:
"""
FeedforwardModule initialization.
Parameters
----------
embed_dim : int
The input and output embedding dimension. The number of channels is the same for input
and output of this module.
feedforward_dim : int
The module hidden dimension.
device : torch.device
            The device used to store the layer weights. Should be
            either torch.device("cpu") or torch.device("cuda").
"""
super().__init__()
self.in_proj = torch.nn.Linear(embed_dim, feedforward_dim, device=device)
self.activation = SwooshL()
self.out_proj = torch.nn.Linear(feedforward_dim, embed_dim, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Does a forward pass of the FeedforwardModule module.
Parameters
----------
x : torch.Tensor[torch.float32]
A float tensor of shape (1, seq_len, embed_dim). The module input.
Returns
-------
torch.Tensor[torch.float32]
A float tensor of shape (1, seq_len, embed_dim).
The module output has the same shape as input.
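
        Examples
        --------
        A minimal shape-check sketch; the sizes below are purely illustrative:

        >>> ff = FeedforwardModule(
        ...     embed_dim=192, feedforward_dim=512, device=torch.device("cpu"),
        ... )
        >>> ff(torch.randn(1, 10, 192)).shape
        torch.Size([1, 10, 192])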
"""
x = self.in_proj(x)
x = self.activation(x)
x = self.out_proj(x)
return x
class NonlinAttention(torch.nn.Module):
"""
This is like the ConvolutionModule, but refactored so that we use multiplication by attention
weights (borrowed from the RelPositionMultiheadAttentionWeights module) instead of actual
convolution. We also took out the second nonlinearity, the one after the attention mechanism.
"""
def __init__(
self, embed_dim: int, att_dim: int, right_context: int, device: torch.device,
) -> None:
"""
NonlinAttention initialization.
Parameters
----------
embed_dim : int
The input and output embedding dimension. The number of channels is the same for input
and output of this module.
att_dim : int
The attention output dimension of this module.
right_context : int
            The module's look-ahead (future) context, used to correctly update the
            left-context cache.
device : torch.device
            The device used to store the layer weights.
            Should be either torch.device("cpu") or torch.device("cuda").
"""
super().__init__()
self.in_proj = torch.nn.Linear(embed_dim, att_dim * 3, device=device)
self.out_proj = torch.nn.Linear(att_dim, embed_dim, device=device)
self.right_context = right_context
def forward(
self, x: torch.Tensor, attn_weights: torch.Tensor, left_cached_x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Does a forward pass of the NonlinAttention module. Returns attention weighted input tensor
and updated attention input tensor cache of the left context.
Parameters
----------
x : torch.Tensor[torch.float32]
An input float tensor of shape (1, seq_len, embed_dim).
attn_weights : torch.Tensor[torch.float32]
            A tensor of shape (1, seq_len, seq_len_2) that corresponds to a single attention
            head, with (seq_len, seq_len_2) being interpreted as (tgt_seq_len, src_seq_len).
            It is expected that attn_weights.sum(dim=2) == 1.0.
Note: the first dimension here corresponds to a batch size.
left_cached_x : torch.Tensor[torch.float32]
A cached attention tensor of the left context of shape (1, left_context_len, att_dim).
Returns
-------
tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]]
A tuple of two float tensors:
- attention weighted output of shape (1, seq_len, embed_dim).
A tensor with the same shape as input x.
- updated cached attention tensor of the left context
of shape (1, left_context_len, att_dim).
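
        Examples
        --------
        A minimal shape-check sketch; the sizes below are illustrative
        (left_context_len == 4, so seq_len_2 == 12 for seq_len == 8):

        >>> nonlin_attn = NonlinAttention(
        ...     embed_dim=192, att_dim=192, right_context=0, device=torch.device("cpu"),
        ... )
        >>> x = torch.randn(1, 8, 192)
        >>> weights = torch.softmax(torch.randn(1, 8, 12), dim=2)
        >>> cache = torch.zeros(1, 4, 192)
        >>> out, new_cache = nonlin_attn(x, weights, cache)
        >>> out.shape, new_cache.shape
        (torch.Size([1, 8, 192]), torch.Size([1, 4, 192]))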
"""
x = self.in_proj(x)
s, x, y = x.chunk(3, dim=2)
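        # s gates x through a tanh nonlinearity; y is kept to re-scale the result of the
        # attention-weighted sum further below.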
x = x * torch.tanh(s)
x = torch.cat((left_cached_x, x), dim=1)
# Update cached tensor
left_cached_x = x[
:,
x.size(1) - self.right_context - left_cached_x.size(1):
x.size(1) - self.right_context,
]
# (1, seq_len, seq_len_2) x (1, seq_len_2, att_dim) -> (1, seq_len, att_dim)
x = torch.matmul(attn_weights, x)
x = x * y
x = self.out_proj(x)
return x, left_cached_x
class ConvolutionModule(torch.nn.Module):
"""
ConvolutionModule in Zipformer2 encoder.
"""
def __init__(
self, embed_dim: int, kernel_size: int, right_context: int, device: torch.device,
) -> None:
"""
ConvolutionModule initialization.
Parameters
----------
embed_dim : int
            The input and output embedding dimension, which is also the number of channels
            of the convolution. The embedding dimension is the same for the input and
            output of this module.
kernel_size : int
The kernel size of the depthwise convolution module.
right_context : int
The module look ahead future context, used to update
causal depthwise convolution left cache correctly.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
"""
super().__init__()
if kernel_size % 2 == 0:
raise ValueError(
                'ConvolutionModule kernel size should be '
f'an odd number but got {kernel_size} instead.',
)
self.in_proj = torch.nn.Linear(embed_dim, 2 * embed_dim, device=device)
self.depthwise_conv = ChunkCausalDepthwiseConv1d(
embed_dim, kernel_size, right_context, device,
)
self.activation = SwooshR()
self.out_proj = torch.nn.Linear(embed_dim, embed_dim, device=device)
def forward(
self, x: torch.Tensor, left_cache: torch.Tensor, src_key_padding_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Does a forward pass of the ConvolutionModule module. Returns processed tensor of the same
shape as input and updated cached convolution tensor of the left context.
Parameters
----------
x : torch.Tensor[torch.float32]
The input float tensor of shape (1, seq_len, embed_dim). The module input.
left_cache : torch.Tensor[torch.float32]
A cached convolution tensor of the left context
of shape (1, embed_dim, left_cache_len).
src_key_padding_mask : torch.Tensor[torch.bool]
The mask for the source keys of shape (1, seq_len),
contains True in masked positions that will be ignored.
Returns
-------
tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]]
A tuple of two float tensors:
- module output of shape (1, seq_len, embed_dim).
A tensor with the same shape as input x.
- updated cached convolution tensor of the left context
of shape (1, embed_dim, left_cache_len).
"""
x = self.in_proj(x) # (1, seq_len, 2 * embed_dim)
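        # GLU-style gating: the projection doubles the channels; the second half gates
        # the first half through a sigmoid.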
x, s = x.chunk(2, dim=2)
x = x * torch.sigmoid(s) # (1, seq_len, embed_dim)
x = torch.masked_fill(x, src_key_padding_mask.unsqueeze(2), 0.0)
# exchange the temporal dimension and the feature dimension for depthwise convolution.
x = x.permute(0, 2, 1) # (1, embed_dim, seq_len).
x, left_cache = self.depthwise_conv(x, left_cache)
x = x.permute(0, 2, 1) # (1, seq_len, embed_dim)
x = self.activation(x)
x = self.out_proj(x) # (1, seq_len, embed_dim)
return x, left_cache
def _test_zipformer_main() -> None:
    # Construction-only smoke test. The configuration below is illustrative and is not
    # taken from any training recipe; it merely has to be internally consistent
    # (5 encoder stacks, matching list lengths).
    device = torch.device("cpu")
    model = Zipformer2(
        input_dim=80,
        subsample_output_dim=192,
        subsample_layer1_channels=8,
        subsample_layer2_channels=32,
        subsample_layer3_channels=128,
        encoder_dims=[192, 256, 384, 256, 192],
        num_encoder_layers=[2, 2, 3, 2, 2],
        downsampling_factors=[1, 2, 4, 2, 1],
        num_heads=[4, 4, 8, 4, 4],
        feedforward_dims=[512, 768, 1024, 768, 512],
        cnn_module_kernels=[31, 31, 15, 31, 31],
        query_head_dim=32,
        pos_head_dim=4,
        value_head_dim=12,
        pos_dim=48,
        pos_max_len=4096,
        output_dim=512,
        use_ctc=False,
        left_context_frames=128,
        right_context_frames=2,
        device=device,
    )
    num_params = sum(p.numel() for p in model.parameters())
    logging.info('Zipformer2 constructed with %d parameters.', num_params)
    # A full streaming forward pass additionally requires the per-layer left-context
    # caches (attention keys/values and convolution caches), so it is not exercised here.
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
    _test_zipformer_main()