#!/usr/bin/env python3
# Copyright (c) 2023 Xiaomi Corp. (author: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
import warnings
from typing import List, Optional, Tuple, Union
import logging
import torch
import random
from encoder_interface import EncoderInterface
from scaling import (
Balancer,
BiasNorm,
Dropout2,
ChunkCausalDepthwiseConv1d,
ActivationDropoutAndLinear,
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
Whiten,
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
penalize_abs_values_gt,
softmax,
ScheduledFloat,
FloatLike,
limit_param_value,
convert_num_channels,
)
from subformer import (
BypassModule,
CompactRelPositionalEncoding,
LearnedDownsamplingModule,
SubformerEncoder,
SubformerEncoderLayer,
)
from zipformer import (
DownsampledZipformer2Encoder,
SimpleDownsample,
SimpleUpsample,
Zipformer2Encoder,
Zipformer2EncoderLayer,
)
from torch import Tensor, nn
class Mixformer(EncoderInterface):
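# The `structure` string is a small spec describing the encoder stack, read left
# to right (interpretation based on the construction loop in __init__ below):
#   'Z' -> a Zipformer2Encoder stack (wrapped in DownsampledZipformer2Encoder
#          when its downsampling factor is > 1),
#   'S' -> a SubformerEncoder stack,
#   '(' -> a LearnedDownsamplingModule that temporally downsamples the sequence,
#   ')' -> the matching upsample plus a BypassModule that closes the '(' region.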
def __init__(
self,
structure: str = "ZZS(S(S)S)SZ",
output_downsampling_factor: int = 2,
downsampling_factor: Tuple[int] = (1, 1, 2, 2, 1),
encoder_dim: Union[int, Tuple[int]] = (
192,
192,
256,
384,
512,
384,
256,
192,
),
num_encoder_layers: Union[int, Tuple[int]] = (
2,
2,
2,
2,
2,
2,
2,
2,
),
encoder_unmasked_dim: Union[int, Tuple[int]] = (192, 192, 192),
query_head_dim: Union[int, Tuple[int]] = (32,),
value_head_dim: Union[int, Tuple[int]] = (12,),
pos_head_dim: Union[int, Tuple[int]] = (4,),
pos_dim: Union[int, Tuple[int]] = (48,),
num_heads: Union[int, Tuple[int]] = (4,),
feedforward_dim: Union[int, Tuple[int]] = (
512,
768,
1024,
1536,
2048,
1536,
1024,
768,
),
cnn_module_kernel: Union[int, Tuple[int]] = (15, 31, 31),
encoder_chunk_sizes: Tuple[Tuple[int, ...]] = ((128, 1024),),
memory_dim: int = -1,
dropout: Optional[FloatLike] = None, # see code below for default
warmup_batches: float = 4000.0,
causal: bool = False,
) -> None:
super(Mixformer, self).__init__()
if dropout is None:
dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
num_zip_encoders = len([s for s in structure if s == 'Z'])
num_sub_encoders = len([s for s in structure if s == 'S'])
num_encoders = num_zip_encoders + num_sub_encoders
num_downsamplers = len([s for s in structure if s == '('])
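# For the default structure "ZZS(S(S)S)SZ": num_zip_encoders = 3,
# num_sub_encoders = 5, num_encoders = 8 and num_downsamplers = 2, matching the
# default lengths of encoder_dim / num_encoder_layers / feedforward_dim (8),
# cnn_module_kernel / encoder_unmasked_dim (3) and downsampling_factor (3 + 2 = 5).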
def _to_tuple(x, length):
"""Converts a single int or a 1-tuple of an int to a tuple with the same length
as downsampling_factor"""
assert isinstance(x, tuple)
if len(x) == 1:
x = x * length
else:
assert len(x) == length and isinstance(
x[0], int
)
return x
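# For example, _to_tuple((4,), 3) == (4, 4, 4), while
# _to_tuple((1, 1, 2, 2, 1), 5) is returned unchanged.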
self.output_downsampling_factor = output_downsampling_factor # int
self.downsampling_factor = (
downsampling_factor
) = _to_tuple(downsampling_factor, num_zip_encoders + num_downsamplers) # tuple
self.encoder_dim = encoder_dim = _to_tuple(encoder_dim, num_encoders) # tuple
self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
encoder_unmasked_dim, num_zip_encoders
) # tuple
num_encoder_layers = _to_tuple(num_encoder_layers, num_encoders)
self.query_head_dim = query_head_dim = _to_tuple(query_head_dim, num_encoders)
self.value_head_dim = value_head_dim = _to_tuple(value_head_dim, num_encoders)
pos_head_dim = _to_tuple(pos_head_dim, num_encoders)
pos_dim = _to_tuple(pos_dim, num_encoders)
self.num_heads = num_heads = _to_tuple(num_heads, num_encoders)
feedforward_dim = _to_tuple(feedforward_dim, num_encoders)
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(
cnn_module_kernel, num_zip_encoders
)
encoder_chunk_sizes = _to_tuple(encoder_chunk_sizes, num_sub_encoders)
self.causal = causal
# for u, d in zip(encoder_unmasked_dim, encoder_dim):
# assert u <= d
# Each encoder stack will be a Zipformer2Encoder (possibly wrapped in a
# DownsampledZipformer2Encoder) or a SubformerEncoder; '(' and ')' add
# LearnedDownsamplingModule / BypassModule pairs around the Subformer stacks.
zip_encoders = []
sub_encoders = []
downsamplers = []
bypasses = []
layer_indexes = []
cur_max_dim = 0
downsampling_factors_list = []
def cur_downsampling_factor():
c = 1
for d in downsampling_factors_list: c *= d
return c
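# cur_downsampling_factor() is the product of the factors of all currently-open
# '(' regions, e.g. inside two nested '(' blocks with factors 2 and 2 it returns 4.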
zip_encoder_dim = []
zip_downsampling_factor = []
for s in structure:
if s == "Z":
i = len(zip_encoders) + len(sub_encoders)
j = len(zip_encoders)
k = len(downsamplers) + len(zip_encoders)
assert encoder_unmasked_dim[j] <= encoder_dim[i]
zip_encoder_dim.append(encoder_dim[i])
encoder_layer = Zipformer2EncoderLayer(
embed_dim=encoder_dim[i],
pos_dim=pos_dim[i],
num_heads=num_heads[i],
query_head_dim=query_head_dim[i],
pos_head_dim=pos_head_dim[i],
value_head_dim=value_head_dim[i],
feedforward_dim=feedforward_dim[i],
dropout=dropout,
cnn_module_kernel=cnn_module_kernel[j],
causal=causal,
)
# For the first segment of the warmup period, we let the Conv2dSubsampling
# layer (which precedes this encoder) learn something. Then we start to warm
# up the encoder stacks one after another.
encoder = Zipformer2Encoder(
encoder_layer,
num_encoder_layers[i],
pos_dim=pos_dim[i],
dropout=dropout,
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
final_layerdrop_rate=0.035 * (downsampling_factor[k] ** 0.5),
)
if downsampling_factor[k] != 1:
encoder = DownsampledZipformer2Encoder(
encoder,
dim=encoder_dim[i],
downsample=downsampling_factor[k],
dropout=dropout,
)
zip_downsampling_factor.append(downsampling_factor[k])
layer_indexes.append(len(zip_encoders))
zip_encoders.append(encoder)
elif s == 'S':
i = len(zip_encoders) + len(sub_encoders)
j = len(sub_encoders)
if len(sub_encoders) == 0:
cur_max_dim = encoder_dim[i]
encoder_layer = SubformerEncoderLayer(
embed_dim=encoder_dim[i],
pos_dim=pos_head_dim[i],
num_heads=num_heads[i],
query_head_dim=query_head_dim[i],
value_head_dim=value_head_dim[i],
feedforward_dim=feedforward_dim[i],
memory_dim=memory_dim,
dropout=dropout,
causal=causal,
)
cur_max_dim = max(cur_max_dim, encoder_dim[i])
encoder = SubformerEncoder(
encoder_layer,
num_encoder_layers[i],
embed_dim=cur_max_dim,
dropout=dropout,
chunk_sizes=encoder_chunk_sizes[j],
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
final_layerdrop_rate=0.035 * (cur_downsampling_factor() ** 0.5),
)
layer_indexes.append(len(sub_encoders))
sub_encoders.append(encoder)
elif s == '(':
i = len(zip_encoders) + len(downsamplers)
downsampler = LearnedDownsamplingModule(cur_max_dim,
downsampling_factor[i])
downsampling_factors_list.append(downsampling_factor[i])
layer_indexes.append(len(downsamplers))
downsamplers.append(downsampler)
else:
assert s == ')'
bypass = BypassModule(cur_max_dim, straight_through_rate=0.0)
layer_indexes.append(len(bypasses))
bypasses.append(bypass)
downsampling_factors_list.pop()
logging.info(f"cur_downsampling_factor={cur_downsampling_factor()}")
self.zip_encoder_dim = zip_encoder_dim
self.zip_downsampling_factor = zip_downsampling_factor
self.layer_indexes = layer_indexes
self.structure = structure
self.zip_encoders = nn.ModuleList(zip_encoders)
self.sub_encoders = nn.ModuleList(sub_encoders)
self.downsamplers = nn.ModuleList(downsamplers)
self.bypasses = nn.ModuleList(bypasses)
self.encoder_pos = CompactRelPositionalEncoding(64, pos_head_dim[0],
dropout_rate=0.15,
length_factor=1.0)
self.downsample_output = SimpleDownsample(
max(encoder_dim),
downsample=output_downsampling_factor,
dropout=dropout,
)
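# _get_full_dim_output combines the stored encoder outputs into a single tensor
# with max(encoder_dim) channels: starting from the last output, it walks the
# earlier outputs backwards and appends the channel range [cur_dim:d] from the
# most recent output that has more channels than seen so far.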
def _get_full_dim_output(self, outputs: List[Tensor]):
num_encoders = len(self.zip_encoders) + 1
assert len(outputs) == num_encoders
output_dim = max(self.encoder_dim)
output_pieces = [outputs[-1]]
cur_dim = self.encoder_dim[-1]
for i in range(num_encoders - 2, -1, -1):
d = list(outputs[i].shape)[-1]
if d > cur_dim:
this_output = outputs[i]
output_pieces.append(this_output[..., cur_dim:d])
cur_dim = d
assert cur_dim == output_dim, (cur_dim, output_dim)
return torch.cat(output_pieces, dim=-1)
def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
"""
In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
randomized feature masks, one per Zipformer2 encoder stack.
On a fraction (feature_mask_dropout_prob) of sequences, these masks zero out all
encoder dims above encoder_unmasked_dim, so in effect those sequences use a
smaller encoder dim.
We generate the random masks at this level because we want the masks to 'agree'
all the way up the encoder stack.
Args:
x: the embeddings (needed for the shape, dtype and device), of shape
(num_frames, batch_size, encoder_dim[0])
"""
num_encoders = len(self.zip_encoders)
if not self.training:
return [1.0] * num_encoders
(num_frames0, batch_size, _encoder_dims0) = x.shape
assert self.encoder_dim[0] == _encoder_dims0
feature_mask_dropout_prob = 0.125
# mask1 shape: (1, batch_size, 1)
mask1 = (
torch.rand(1, batch_size, 1, device=x.device)
> feature_mask_dropout_prob
).to(x.dtype)
# mask2 has additional sequences masked, about twice the number.
mask2 = torch.logical_and(
mask1,
(
torch.rand(1, batch_size, 1, device=x.device)
> feature_mask_dropout_prob
).to(x.dtype),
)
# dim: (1, batch_size, 2)
mask = torch.cat((mask1, mask2), dim=-1)
feature_masks = []
for i in range(num_encoders):
channels = self.zip_encoder_dim[i]
feature_mask = torch.ones(
1, batch_size, channels, dtype=x.dtype, device=x.device
)
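# Channels below u1 are never masked; channels in [u1, u2) are dropped
# according to mask1, and channels in [u2, channels) according to mask2
# (which masks roughly twice as many sequences).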
u1 = self.encoder_unmasked_dim[i]
u2 = u1 + (channels - u1) // 2
feature_mask[:, :, u1:u2] *= mask[..., 0:1]
feature_mask[:, :, u2:] *= mask[..., 1:2]
feature_masks.append(feature_mask)
return feature_masks
def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]:
"""
Return an attention offset of shape (batch_size, seq_len, seq_len), interpreted as
(batch_size, tgt_seq_len, src_seq_len).  Entries masked by the causal constraint
(when self.causal is True) or by src_key_padding_mask are set to -inf; all other
entries are zero.
Args:
x: the input embeddings to this encoder, of shape (seq_len, batch_size, embed_dim).
src_key_padding_mask: optional key-padding mask of shape (batch_size, seq_len) with True in masked positions.
"""
seq_len, batch_size, _num_channels = x.shape
ans = torch.zeros(batch_size, seq_len, seq_len, device=x.device)
if self.causal:
# t is frame index, shape (seq_len,)
t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
src_t = t
tgt_t = t.unsqueeze(-1)
attn_mask = (src_t > tgt_t)
ans.masked_fill_(attn_mask, float('-inf'))
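# e.g. for seq_len == 3 the causal part of each ans[b] is
# [[0, -inf, -inf], [0, 0, -inf], [0, 0, 0]].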
if src_key_padding_mask is not None:
ans.masked_fill_(src_key_padding_mask.unsqueeze(1), float('-inf'))
# now ans: (batch_size, seq_len, seq_len).
return ans
def forward(
self,
x: Tensor,
x_lens: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""
Args:
x:
The input tensor. Its shape is (seq_len, batch_size, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
src_key_padding_mask:
The mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.
Returns:
Return a tuple containing 2 tensors:
- embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
- lengths, a tensor of shape (batch_size,) containing the number
of frames in `embeddings` before padding.
"""
outputs = []
attn_offsets = [ self._get_attn_offset(x, src_key_padding_mask) ]
pos_embs = [ self.encoder_pos(x) ]
downsample_info = []
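# attn_offsets, pos_embs and downsample_info behave as stacks: each '(' pushes
# the downsampled versions and each ')' pops them when upsampling back.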
if torch.jit.is_scripting():
feature_masks = [1.0] * len(self.zip_encoders)
else:
feature_masks = self.get_feature_masks(x)
for s, i in zip(self.structure, self.layer_indexes):
if s == 'Z':
encoder = self.zip_encoders[i]
ds = self.zip_downsampling_factor[i]
x = convert_num_channels(x, self.zip_encoder_dim[i])
x = encoder(
x,
feature_mask=feature_masks[i],
src_key_padding_mask=(
None
if src_key_padding_mask is None
else src_key_padding_mask[..., ::ds]
),
)
outputs.append(x)
elif s == 'S':
encoder = self.sub_encoders[i] # one encoder stack
x = encoder(x,
pos_embs[-1],
attn_offset=attn_offsets[-1])
# Only the output of the last Subformer stack is kept for combining
# into the final output.
if i == len(self.sub_encoders) - 1:
outputs.append(x)
# x will have the maximum dimension up till now, even if
# `encoder` uses lower dim in its layers.
elif s == '(':
downsampler = self.downsamplers[i]
indexes, weights, x_new = downsampler(x)
downsample_info.append((indexes, weights, x))
x = x_new
pos_embs.append(downsampler.downsample_pos_emb(pos_embs[-1], indexes))
attn_offsets.append(downsampler.downsample_attn_offset(attn_offsets[-1],
indexes,
weights))
else:
assert s == ')' # upsample and bypass
indexes, weights, x_orig = downsample_info.pop()
_attn_offset = attn_offsets.pop()
_pos_emb = pos_embs.pop()
x_orig = convert_num_channels(x_orig, x.shape[-1])
x = LearnedDownsamplingModule.upsample(x_orig, x, indexes, weights)
bypass = self.bypasses[i]
x = bypass(x_orig, x)
# Only "balanced" structure is supported now
assert len(downsample_info) == 0, len(downsample_info)
# if the last output has the largest dimension, x will be unchanged,
# it will be the same as outputs[-1]. Otherwise it will be concatenated
# from different pieces of 'outputs', taking each dimension from the
# most recent output that has it present.
x = self._get_full_dim_output(outputs)
x = self.downsample_output(x)
# SimpleDownsample has this rounding behavior.
assert self.output_downsampling_factor == 2
if torch.jit.is_scripting():
lengths = (x_lens + 1) // 2
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
lengths = (x_lens + 1) // 2
return x, lengths
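# A minimal, untested smoke-test sketch (not part of the original file): it
# assumes the sibling modules (encoder_interface.py, scaling.py, zipformer.py,
# subformer.py) are importable, and simply runs random features with the
# default Mixformer configuration to check shapes.
if __name__ == "__main__":
    torch.manual_seed(0)
    seq_len, batch_size, feature_dim = 100, 2, 192  # feature_dim == encoder_dim[0]
    model = Mixformer()
    model.eval()  # avoid the training-time feature-mask randomness
    x = torch.randn(seq_len, batch_size, feature_dim)
    x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)
    with torch.no_grad():
        y, y_lens = model(x, x_lens)
    # Expected roughly (seq_len // output_downsampling_factor, batch_size, max(encoder_dim)).
    print(y.shape, y_lens)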