#!/usr/bin/env python3
# Copyright (c) 2023 Xiaomi Corp. (author: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import math
import random
import warnings
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor, nn

from encoder_interface import EncoderInterface
from scaling import (
    Balancer,
    BiasNorm,
    Dropout2,
    ChunkCausalDepthwiseConv1d,
    ActivationDropoutAndLinear,
    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
    Whiten,
    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
    penalize_abs_values_gt,
    softmax,
    ScheduledFloat,
    FloatLike,
    limit_param_value,
    convert_num_channels,
)
from subformer import (
    BypassModule,
    CompactRelPositionalEncoding,
    LearnedDownsamplingModule,
    SubformerEncoder,
    SubformerEncoderLayer,
)
from zipformer import (
    DownsampledZipformer2Encoder,
    SimpleDownsample,
    SimpleUpsample,
    Zipformer2Encoder,
    Zipformer2EncoderLayer,
)


class Mixformer(EncoderInterface):
    def __init__(
        self,
        structure: str = "ZZS(S(S)S)SZ",
        output_downsampling_factor: int = 2,
        downsampling_factor: Tuple[int] = (1, 1, 2, 2, 1),
        encoder_dim: Union[int, Tuple[int]] = (
            192, 192, 256, 384, 512, 384, 256, 192,
        ),
        num_encoder_layers: Union[int, Tuple[int]] = (2, 2, 2, 2, 2, 2, 2, 2),
        encoder_unmasked_dim: Union[int, Tuple[int]] = (192, 192, 192),
        query_head_dim: Union[int, Tuple[int]] = (32,),
        value_head_dim: Union[int, Tuple[int]] = (12,),
        pos_head_dim: Union[int, Tuple[int]] = (4,),
        pos_dim: Union[int, Tuple[int]] = (48,),
        num_heads: Union[int, Tuple[int]] = (4,),
        feedforward_dim: Union[int, Tuple[int]] = (
            512, 768, 1024, 1536, 2048, 1536, 1024, 768,
        ),
        cnn_module_kernel: Union[int, Tuple[int]] = (15, 31, 31),
        encoder_chunk_sizes: Tuple[Tuple[int, ...]] = ((128, 1024),),
        memory_dim: int = -1,
        dropout: Optional[FloatLike] = None,  # see code below for default
        warmup_batches: float = 4000.0,
        causal: bool = False,
    ) -> None:
        super(Mixformer, self).__init__()

        if dropout is None:
            dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
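            # ScheduledFloat is piecewise-linear in the global batch count, so this
            # default starts dropout at 0.3 and decays it to 0.1 by batch 20000,
            # staying constant afterwards.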

        num_zip_encoders = len([s for s in structure if s == "Z"])
        num_sub_encoders = len([s for s in structure if s == "S"])
        num_encoders = num_zip_encoders + num_sub_encoders
        num_downsamplers = len([s for s in structure if s == "("])

        def _to_tuple(x, length):
"""Converts a single int or a 1-tuple of an int to a tuple with the same length
|
|
as downsampling_factor"""
|
|
assert isinstance(x, tuple)
|
|
if len(x) == 1:
|
|
x = x * length
|
|
else:
|
|
assert len(x) == length and isinstance(
|
|
x[0], int
|
|
)
|
|
return x
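        # E.g. with the defaults above, _to_tuple((32,), 8) broadcasts to
        # (32, 32, 32, 32, 32, 32, 32, 32), while _to_tuple((15, 31, 31), 3)
        # is returned unchanged.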
        self.output_downsampling_factor = output_downsampling_factor  # int
        self.downsampling_factor = downsampling_factor = _to_tuple(
            downsampling_factor, num_zip_encoders + num_downsamplers
        )  # tuple
        self.encoder_dim = encoder_dim = _to_tuple(encoder_dim, num_encoders)  # tuple
        self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
            encoder_unmasked_dim, num_zip_encoders
        )  # tuple
        num_encoder_layers = _to_tuple(num_encoder_layers, num_encoders)
        self.query_head_dim = query_head_dim = _to_tuple(query_head_dim, num_encoders)
        self.value_head_dim = value_head_dim = _to_tuple(value_head_dim, num_encoders)
        pos_head_dim = _to_tuple(pos_head_dim, num_encoders)
        pos_dim = _to_tuple(pos_dim, num_encoders)
        self.num_heads = num_heads = _to_tuple(num_heads, num_encoders)
        feedforward_dim = _to_tuple(feedforward_dim, num_encoders)
        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(
            cnn_module_kernel, num_zip_encoders
        )
        encoder_chunk_sizes = _to_tuple(encoder_chunk_sizes, num_sub_encoders)

        self.causal = causal

        # for u, d in zip(encoder_unmasked_dim, encoder_dim):
        #     assert u <= d

        # each one will be Zipformer2Encoder, DownsampledZipformer2Encoder,
        # SubformerEncoder or DownsampledSubformerEncoder
        zip_encoders = []
        sub_encoders = []
        downsamplers = []
        bypasses = []

        layer_indexes = []

        cur_max_dim = 0

        downsampling_factors_list = []

        def cur_downsampling_factor():
            c = 1
            for d in downsampling_factors_list:
                c *= d
            return c

        zip_encoder_dim = []
        zip_downsampling_factor = []
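        # The `structure` string is parsed one character at a time below:
        # 'Z' adds a (possibly downsampled) Zipformer2 encoder stack, 'S' adds a
        # Subformer encoder stack, '(' creates a LearnedDownsamplingModule, and the
        # matching ')' creates the BypassModule used when upsampling back in
        # forward().  Parentheses must be balanced, as in the default "ZZS(S(S)S)SZ".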
        for s in structure:
            if s == "Z":
                i = len(zip_encoders) + len(sub_encoders)
                j = len(zip_encoders)
                k = len(downsamplers) + len(zip_encoders)
                assert encoder_unmasked_dim[j] <= encoder_dim[i]
                zip_encoder_dim.append(encoder_dim[i])
                encoder_layer = Zipformer2EncoderLayer(
                    embed_dim=encoder_dim[i],
                    pos_dim=pos_dim[i],
                    num_heads=num_heads[i],
                    query_head_dim=query_head_dim[i],
                    pos_head_dim=pos_head_dim[i],
                    value_head_dim=value_head_dim[i],
                    feedforward_dim=feedforward_dim[i],
                    dropout=dropout,
                    cnn_module_kernel=cnn_module_kernel[j],
                    causal=causal,
                )

                # For the first part of the warmup period, we let the Conv2dSubsampling
                # layer learn something.  Then we start to warm up the other encoders.
                encoder = Zipformer2Encoder(
                    encoder_layer,
                    num_encoder_layers[i],
                    pos_dim=pos_dim[i],
                    dropout=dropout,
                    warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                    warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                    final_layerdrop_rate=0.035 * (downsampling_factor[k] ** 0.5),
                )

                if downsampling_factor[k] != 1:
                    encoder = DownsampledZipformer2Encoder(
                        encoder,
                        dim=encoder_dim[i],
                        downsample=downsampling_factor[k],
                        dropout=dropout,
                    )
                zip_downsampling_factor.append(downsampling_factor[k])
                layer_indexes.append(len(zip_encoders))
                zip_encoders.append(encoder)
            elif s == "S":
                i = len(zip_encoders) + len(sub_encoders)
                j = len(sub_encoders)
                if len(sub_encoders) == 0:
                    cur_max_dim = encoder_dim[i]
                encoder_layer = SubformerEncoderLayer(
                    embed_dim=encoder_dim[i],
                    pos_dim=pos_head_dim[i],
                    num_heads=num_heads[i],
                    query_head_dim=query_head_dim[i],
                    value_head_dim=value_head_dim[i],
                    feedforward_dim=feedforward_dim[i],
                    memory_dim=memory_dim,
                    dropout=dropout,
                    causal=causal,
                )
                cur_max_dim = max(cur_max_dim, encoder_dim[i])
                encoder = SubformerEncoder(
                    encoder_layer,
                    num_encoder_layers[i],
                    embed_dim=cur_max_dim,
                    dropout=dropout,
                    chunk_sizes=encoder_chunk_sizes[j],
                    warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                    warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                    final_layerdrop_rate=0.035 * (cur_downsampling_factor() ** 0.5),
                )
                layer_indexes.append(len(sub_encoders))
                sub_encoders.append(encoder)
            elif s == "(":
                i = len(zip_encoders) + len(downsamplers)
                downsampler = LearnedDownsamplingModule(
                    cur_max_dim, downsampling_factor[i]
                )
                downsampling_factors_list.append(downsampling_factor[i])
                layer_indexes.append(len(downsamplers))
                downsamplers.append(downsampler)
            else:
                assert s == ")"
                bypass = BypassModule(cur_max_dim, straight_through_rate=0.0)
                layer_indexes.append(len(bypasses))
                bypasses.append(bypass)
                downsampling_factors_list.pop()

        logging.info(f"cur_downsampling_factor={cur_downsampling_factor()}")

        self.zip_encoder_dim = zip_encoder_dim
        self.zip_downsampling_factor = zip_downsampling_factor
        self.layer_indexes = layer_indexes
        self.structure = structure
        self.zip_encoders = nn.ModuleList(zip_encoders)
        self.sub_encoders = nn.ModuleList(sub_encoders)
        self.downsamplers = nn.ModuleList(downsamplers)
        self.bypasses = nn.ModuleList(bypasses)

        self.encoder_pos = CompactRelPositionalEncoding(
            64, pos_head_dim[0], dropout_rate=0.15, length_factor=1.0
        )

        self.downsample_output = SimpleDownsample(
            max(encoder_dim),
            downsample=output_downsampling_factor,
            dropout=dropout,
        )

    def _get_full_dim_output(self, outputs: List[Tensor]):
        num_encoders = len(self.zip_encoders) + 1
        assert len(outputs) == num_encoders
        output_dim = max(self.encoder_dim)
        output_pieces = [outputs[-1]]
        cur_dim = self.encoder_dim[-1]
        for i in range(num_encoders - 2, -1, -1):
            d = list(outputs[i].shape)[-1]
            if d > cur_dim:
                this_output = outputs[i]
                output_pieces.append(this_output[..., cur_dim:d])
                cur_dim = d
        assert cur_dim == output_dim, (cur_dim, output_dim)
        return torch.cat(output_pieces, dim=-1)
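    # With the default structure "ZZS(S(S)S)SZ" and encoder_dim
    # (192, 192, 256, 384, 512, 384, 256, 192), `outputs` holds the three
    # Zipformer2 outputs plus the final Subformer output, with feature dims
    # (192, 192, 512, 192); starting from the last 192-dim output we append
    # channels 192:512 of the 512-dim Subformer output, giving
    # max(encoder_dim) == 512 channels in total.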

    def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
        """
        In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
        randomized feature masks, one per encoder.
        On e.g. 15% of frames, these masks will zero out all encoder dims larger than
        some supplied number, e.g. >256, so in effect on those frames we are using
        a smaller encoder dim.

        We generate the random masks at this level because we want the 2 masks to 'agree'
        all the way up the encoder stack. This will mean that the 1st mask will have
        mask values repeated self.zipformer_subsampling_factor times.

        Args:
           x: the embeddings (needed for the shape and dtype and device), of shape
             (1, batch_size, encoder_dims0)
        """
        num_encoders = len(self.zip_encoders)
        if not self.training:
            return [1.0] * num_encoders

        (num_frames0, batch_size, _encoder_dims0) = x.shape

        assert self.encoder_dim[0] == _encoder_dims0

        feature_mask_dropout_prob = 0.125

        # mask1 shape: (1, batch_size, 1)
        mask1 = (
            torch.rand(1, batch_size, 1, device=x.device)
            > feature_mask_dropout_prob
        ).to(x.dtype)

        # mask2 has additional sequences masked, about twice the number.
        mask2 = torch.logical_and(
            mask1,
            (
                torch.rand(1, batch_size, 1, device=x.device)
                > feature_mask_dropout_prob
            ).to(x.dtype),
        )

        # dim: (1, batch_size, 2)
        mask = torch.cat((mask1, mask2), dim=-1)

        feature_masks = []
        for i in range(num_encoders):
            channels = self.zip_encoder_dim[i]
            feature_mask = torch.ones(
                1, batch_size, channels, dtype=x.dtype, device=x.device
            )
            u1 = self.encoder_unmasked_dim[i]
            u2 = u1 + (channels - u1) // 2

            feature_mask[:, :, u1:u2] *= mask[..., 0:1]
            feature_mask[:, :, u2:] *= mask[..., 1:2]

            feature_masks.append(feature_mask)

        return feature_masks
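    # For example, if a Zipformer2 stack had channels=384 with
    # encoder_unmasked_dim=192, then u1=192 and u2 = 192 + (384 - 192) // 2 = 288:
    # dims [192, 288) are zeroed wherever mask1 is 0 and dims [288, 384) wherever
    # mask2 is 0, while the first 192 dims are never masked.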

    def _get_attn_offset(
        self, x: Tensor, src_key_padding_mask: Optional[Tensor]
    ) -> Optional[Tensor]:
        """
        Return attention offset of shape (1 or batch_size, seq_len, seq_len), interpreted as
        (1 or batch_size, tgt_seq_len, src_seq_len); this reflects masking if causal == True,
        otherwise it will be all zeros.

        Args:
          x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
          src_key_padding_mask: optional key-padding mask of shape (batch_size, seq_len)
            with True in masked positions.
        """
        seq_len, batch_size, _num_channels = x.shape

        ans = torch.zeros(batch_size, seq_len, seq_len, device=x.device)

        if self.causal:
            # t is frame index, shape (seq_len,)
            t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
            src_t = t
            tgt_t = t.unsqueeze(-1)
            attn_mask = src_t > tgt_t
            ans.masked_fill_(attn_mask, float("-inf"))

        if src_key_padding_mask is not None:
            ans.masked_fill_(src_key_padding_mask.unsqueeze(1), float("-inf"))
        # now ans: (batch_size, seq_len, seq_len).
        return ans
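    # With causal=True and seq_len=4 (and no padding), the offset added to the
    # attention scores is:
    #   [[0, -inf, -inf, -inf],
    #    [0,    0, -inf, -inf],
    #    [0,    0,    0, -inf],
    #    [0,    0,    0,    0]]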

    def forward(
        self,
        x: Tensor,
        x_lens: Tensor,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          x:
            The input tensor. Its shape is (seq_len, batch_size, feature_dim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
          src_key_padding_mask:
            The mask for padding, of shape (batch_size, seq_len); True means
            masked position. May be None.
        Returns:
          Return a tuple containing 2 tensors:
            - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
            - lengths, a tensor of shape (batch_size,) containing the number
              of frames in `embeddings` before padding.
        """
        outputs = []

        attn_offsets = [self._get_attn_offset(x, src_key_padding_mask)]
        pos_embs = [self.encoder_pos(x)]
        downsample_info = []
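        # pos_embs, attn_offsets and downsample_info act as stacks mirroring the
        # bracket structure of `self.structure`: every '(' pushes the downsampled
        # positional embedding, attention offset and pre-downsampling activations,
        # and the matching ')' pops them to upsample and bypass.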

        if torch.jit.is_scripting():
            feature_masks = [1.0] * len(self.zip_encoders)
        else:
            feature_masks = self.get_feature_masks(x)

        for s, i in zip(self.structure, self.layer_indexes):
            if s == "Z":
                encoder = self.zip_encoders[i]
                ds = self.zip_downsampling_factor[i]
                x = convert_num_channels(x, self.zip_encoder_dim[i])
                x = encoder(
                    x,
                    feature_mask=feature_masks[i],
                    src_key_padding_mask=(
                        None
                        if src_key_padding_mask is None
                        else src_key_padding_mask[..., ::ds]
                    ),
                )
                outputs.append(x)
            elif s == "S":
                encoder = self.sub_encoders[i]  # one encoder stack
                x = encoder(x, pos_embs[-1], attn_offset=attn_offsets[-1])

                # only the last output of subformer will be used when combining the
                # final output.
                if i == len(self.sub_encoders) - 1:
                    outputs.append(x)
                # x will have the maximum dimension up till now, even if
                # `encoder` uses lower dim in its layers.
            elif s == "(":
                downsampler = self.downsamplers[i]

                indexes, weights, x_new = downsampler(x)
                downsample_info.append((indexes, weights, x))
                x = x_new

                pos_embs.append(downsampler.downsample_pos_emb(pos_embs[-1], indexes))

                attn_offsets.append(
                    downsampler.downsample_attn_offset(
                        attn_offsets[-1], indexes, weights
                    )
                )
            else:
                assert s == ")"  # upsample and bypass
                indexes, weights, x_orig = downsample_info.pop()
                _attn_offset = attn_offsets.pop()
                _pos_emb = pos_embs.pop()
                x_orig = convert_num_channels(x_orig, x.shape[-1])

                x = LearnedDownsamplingModule.upsample(x_orig, x, indexes, weights)

                bypass = self.bypasses[i]
                x = bypass(x_orig, x)

        # Only "balanced" structure is supported now
        assert len(downsample_info) == 0, len(downsample_info)

        # if the last output has the largest dimension, x will be unchanged,
        # it will be the same as outputs[-1]. Otherwise it will be concatenated
        # from different pieces of 'outputs', taking each dimension from the
        # most recent output that has it present.
        x = self._get_full_dim_output(outputs)
        x = self.downsample_output(x)
        # class SimpleDownsample has this rounding behavior..
        assert self.output_downsampling_factor == 2
        if torch.jit.is_scripting():
            lengths = (x_lens + 1) // 2
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                lengths = (x_lens + 1) // 2

        return x, lengths
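

# The block below is an illustrative smoke-test sketch, not part of the original
# recipe: it assumes this file lives next to the icefall `scaling.py`, `subformer.py`,
# `zipformer.py` and `encoder_interface.py` modules it imports, and simply runs one
# forward pass with the default configuration (feature_dim == encoder_dim[0] == 192).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    torch.manual_seed(0)

    model = Mixformer()
    model.eval()  # eval mode: no feature masking / layer dropout

    seq_len, batch_size = 100, 2
    x = torch.randn(seq_len, batch_size, 192)
    x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)

    with torch.no_grad():
        y, y_lens = model(x, x_lens)
    # Expect roughly seq_len // output_downsampling_factor output frames.
    logging.info(f"output shape: {tuple(y.shape)}, lengths: {y_lens.tolist()}")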