diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py
index 64fde095b..d011e0ff3 100644
--- a/egs/librispeech/ASR/zipformer/model.py
+++ b/egs/librispeech/ASR/zipformer/model.py
@@ -16,16 +16,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, List
+from typing import List, Optional, Tuple
 
 import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 from multi_quantization.prediction import JointCodebookLoss
+from scaling import ScaledLinear
 
 from icefall.utils import add_sos, make_pad_mask
-from scaling import ScaledLinear
 
 
 class AsrModel(nn.Module):
@@ -118,7 +118,7 @@ class AsrModel(nn.Module):
                 nn.Linear(encoder_dim, vocab_size),
                 nn.LogSoftmax(dim=-1),
             )
-
+
         if num_codebooks > 0:
             self.codebook_loss_net = JointCodebookLoss(
                 predictor_channels=cb_input_dim,
@@ -141,8 +141,8 @@ class AsrModel(nn.Module):
             Encoder output, of shape (N, T, C).
           encoder_out_lens:
             Encoder output lengths, of shape (N,).
-          saved_embeddings:
-            The embeddings from the middle layers
+          saved_embeddings:
+            The embeddings from the middle layers
         """
         # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
         x, x_lens = self.encoder_embed(x, x_lens)
@@ -151,7 +151,9 @@ class AsrModel(nn.Module):
         src_key_padding_mask = make_pad_mask(x_lens)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
-        encoder_out, encoder_out_lens, middle_out = self.encoder(x, x_lens, src_key_padding_mask)
+        encoder_out, encoder_out_lens, middle_out = self.encoder(
+            x, x_lens, src_key_padding_mask
+        )
         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
 
         assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
@@ -324,9 +326,9 @@ class AsrModel(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
-          codebook_indexes:
-            The codebook indexes to be predicted.
Only used when doing knowledge + distillation with MVQ Returns: Return the transducer losses and CTC loss, and potentially codebook loss in form of (simple_loss, pruned_loss, ctc_loss, codebook_loss) @@ -375,18 +377,18 @@ class AsrModel(nn.Module): ) else: ctc_loss = torch.empty(0) - + if self.training and hasattr(self, "codebook_loss_net"): assert codebook_indexes is not None codebook_loss = self.forward_codebook( - middle_out=middle_out, - codebook_indexes=codebook_indexes, - ) + middle_out=middle_out, + codebook_indexes=codebook_indexes, + ) else: codebook_loss = torch.empty(0) return simple_loss, pruned_loss, ctc_loss, codebook_loss - + def forward_codebook( self, middle_out: List[torch.Tensor], @@ -394,21 +396,23 @@ class AsrModel(nn.Module): ) -> torch.Tensor: """Calculate the codebook loss for the model (knowledge distillation) - Args: - middle_out (List[torch.Tensor]): - The embeddings extracted from the middle layer of the zipformer encoder - codebook_indexes (torch.Tensor): - The encoded codebook indexes for knowledge distillation + Args: + middle_out (List[torch.Tensor]): + The embeddings extracted from the middle layer of the zipformer encoder + codebook_indexes (torch.Tensor): + The encoded codebook indexes for knowledge distillation - Returns: - The codebook loss value - """ - middle_layer_output = middle_out[0] # currently only support using output of one layer, (N,T,C) + Returns: + The codebook loss value + """ + middle_layer_output = middle_out[ + 0 + ] # currently only support using output of one layer, (N,T,C) len_CI = codebook_indexes.size(1) len_mid_layer = middle_layer_output.size(1) - ratio = round(len_CI/len_mid_layer) - - if ratio == 1: # Having the same frame rate + ratio = round(len_CI / len_mid_layer) + + if ratio == 1: # Having the same frame rate assert len_CI > len_mid_layer, (len_CI, len_mid_layer) codebook_indexes = codebook_indexes[:, :len_mid_layer, :] assert codebook_indexes.size(1) == middle_layer_output.size(1) @@ -422,13 +426,11 @@ class AsrModel(nn.Module): codebook_loss = self.codebook_loss_net( middle_layer_output, codebook_indexes ) - + return codebook_loss - + @staticmethod - def concat_successive_codebook_indexes( - middle_layer_output, codebook_indexes - ): + def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes): # Output rate of hubert is 50 frames per second, # while that of current encoder is 25. # Following code handling two issues: @@ -448,10 +450,12 @@ class AsrModel(nn.Module): # Handling issue 1. if T >= t_expected * 2: codebook_indexes = codebook_indexes[:, : t_expected * 2, :] - if T / t_expected < 1.1: # To be changed, dirty hack to jump out of this function - codebook_indexes = codebook_indexes[:, : t_expected, :] - assert middle_layer_output.shape[1] == codebook_indexes.shape[1] - return codebook_indexes + if ( + T / t_expected < 1.1 + ): # To be changed, dirty hack to jump out of this function + codebook_indexes = codebook_indexes[:, :t_expected, :] + assert middle_layer_output.shape[1] == codebook_indexes.shape[1] + return codebook_indexes # Handling issue 2. 
codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2) assert middle_layer_output.shape[1] == codebook_indexes.shape[1] diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index f5142650b..a6cb48205 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -403,28 +403,28 @@ def get_parser(): default=0.2, help="Scale for CTC loss.", ) - + parser.add_argument( "--enable-distillation", type=str2bool, default=True, help="Whether to eanble distillation.", ) - + parser.add_argument( "--codebook-loss-scale", type=float, default=0.1, help="The scale of codebook loss.", ) - + parser.add_argument( "--num-codebooks", type=int, default=16, help="Number of codebooks used for the extracted CI", ) - + parser.add_argument( "--distillation-layer", type=int, @@ -636,11 +636,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: def get_model(params: AttributeDict) -> nn.Module: - assert ( - params.use_transducer or params.use_ctc - ), (f"At least one of them should be True, " + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}") + f"params.use_ctc={params.use_ctc}" + ) encoder_embed = get_encoder_embed(params) encoder = get_encoder_model(params) @@ -783,17 +783,17 @@ def save_checkpoint( best_valid_filename = params.exp_dir / "best-valid-loss.pt" copyfile(src=filename, dst=best_valid_filename) + def extract_codebook_indexes(batch: Dict) -> Tuple[Tensor, Tensor]: cuts = batch["supervisions"]["cut"] # -100 is identical to ignore_value in CE loss computation. - cuts_pre_mixed = [ - c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts - ] + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] codebook_indexes, codebook_indexes_lens = collate_custom_field( cuts_pre_mixed, "codebook_indexes", pad_value=-100 ) return codebook_indexes, codebook_indexes_lens + def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], @@ -834,7 +834,7 @@ def compute_loss( texts = batch["supervisions"]["text"] y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) - + if is_training and params.enable_distillation: codebook_indexes, _ = extract_codebook_indexes(batch) codebook_indexes = codebook_indexes.to(device) @@ -859,21 +859,20 @@ def compute_loss( # take down the scale on the simple loss from 1.0 at the start # to params.simple_loss scale by warm_step. simple_loss_scale = ( - s if batch_idx_train >= warm_step + s + if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) pruned_loss_scale = ( - 1.0 if batch_idx_train >= warm_step + 1.0 + if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step) ) - loss += ( - simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss - ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss - + if is_training and params.enable_distillation: loss += params.codebook_loss_scale * codebook_loss @@ -1164,7 +1163,9 @@ def run(rank, world_size, args): # Note: it's better to set --spec-aug-time-warpi-factor=-1 # when doing distillation with vq. 
if params.enable_distillation: - assert args.spec_aug_time_warp_factor < 1, "Specaug should be disabled during distillation" + assert ( + args.spec_aug_time_warp_factor < 1 + ), "Specaug should be disabled during distillation" device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7897057ff..ede30451a 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -17,28 +17,33 @@ # limitations under the License. import copy +import logging import math +import random import warnings from typing import List, Optional, Tuple, Union -import logging + import torch -import random from encoder_interface import EncoderInterface from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. +) +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) +from scaling import ( + ActivationDropoutAndLinear, Balancer, BiasNorm, - Dropout2, ChunkCausalDepthwiseConv1d, - ActivationDropoutAndLinear, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + Dropout2, + FloatLike, + ScheduledFloat, Whiten, - Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. + convert_num_channels, + limit_param_value, penalize_abs_values_gt, softmax, - ScheduledFloat, - FloatLike, - limit_param_value, - convert_num_channels, ) from torch import Tensor, nn @@ -93,35 +98,35 @@ class Zipformer2(EncoderInterface): middle_output_layer: Get the output of a middle layer of the model """ + def __init__( - self, - output_downsampling_factor: int = 2, - downsampling_factor: Tuple[int] = (2, 4), - encoder_dim: Union[int, Tuple[int]] = 384, - num_encoder_layers: Union[int, Tuple[int]] = 4, - encoder_unmasked_dim: Union[int, Tuple[int]] = 256, - query_head_dim: Union[int, Tuple[int]] = 24, - pos_head_dim: Union[int, Tuple[int]] = 4, - value_head_dim: Union[int, Tuple[int]] = 12, - num_heads: Union[int, Tuple[int]] = 8, - feedforward_dim: Union[int, Tuple[int]] = 1536, - cnn_module_kernel: Union[int, Tuple[int]] = 31, - pos_dim: int = 192, - dropout: FloatLike = None, # see code below for default - warmup_batches: float = 4000.0, - causal: bool = False, - chunk_size: Tuple[int] = [-1], - left_context_frames: Tuple[int] = [-1], - middle_output_layer: int = None # 0-based layer index + self, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + encoder_unmasked_dim: Union[int, Tuple[int]] = 256, + query_head_dim: Union[int, Tuple[int]] = 24, + pos_head_dim: Union[int, Tuple[int]] = 4, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_dim: Union[int, Tuple[int]] = 1536, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + pos_dim: int = 192, + dropout: FloatLike = None, # see code below for default + warmup_batches: float = 4000.0, + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + middle_output_layer: int = None, # 0-based layer index ) -> None: super(Zipformer2, self).__init__() if dropout is None: - dropout = ScheduledFloat((0.0, 0.3), - (20000.0, 0.1)) + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) def _to_tuple(x): - """ Converts a single int or a 1-tuple of an int to a tuple with the same 
length + """Converts a single int or a 1-tuple of an int to a tuple with the same length as downsampling_factor""" if isinstance(x, int): x = (x,) @@ -131,10 +136,12 @@ class Zipformer2(EncoderInterface): assert len(x) == len(downsampling_factor) and isinstance(x[0], int) return x - self.output_downsampling_factor = output_downsampling_factor # int - self.downsampling_factor = downsampling_factor # tuple - self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple - self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple + self.output_downsampling_factor = output_downsampling_factor # int + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple( + encoder_unmasked_dim + ) # tuple num_encoder_layers = _to_tuple(num_encoder_layers) self.num_encoder_layers = num_encoder_layers self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) @@ -148,7 +155,7 @@ class Zipformer2(EncoderInterface): self.chunk_size = chunk_size self.left_context_frames = left_context_frames - for u,d in zip(encoder_unmasked_dim, encoder_dim): + for u, d in zip(encoder_unmasked_dim, encoder_dim): assert u <= d # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder @@ -193,25 +200,22 @@ class Zipformer2(EncoderInterface): encoders.append(encoder) self.encoders = nn.ModuleList(encoders) - + # for mvq: return the middle layer output output_layers = [] if middle_output_layer is not None: - assert ( - middle_output_layer >= 0 - and middle_output_layer < len(num_encoder_layers) + assert middle_output_layer >= 0 and middle_output_layer < len( + num_encoder_layers ) output_layers.append(middle_output_layer) - - self.output_layers = output_layers # A list of int - self.downsample_output = SimpleDownsample(max(encoder_dim), - downsample=output_downsampling_factor, - dropout=dropout) + self.output_layers = output_layers # A list of int - def get_feature_masks( - self, - x: Tensor) -> Union[List[float], List[Tensor]]: + self.downsample_output = SimpleDownsample( + max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout + ) + + def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: """ In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of randomized feature masks, one per encoder. @@ -229,24 +233,30 @@ class Zipformer2(EncoderInterface): """ num_encoders = len(self.encoder_dim) if not self.training: - return [ 1.0 ] * num_encoders + return [1.0] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape - assert self.encoder_dim[0] == _encoder_dims0, (self.encoder_dim[0], _encoder_dims0) + assert self.encoder_dim[0] == _encoder_dims0, ( + self.encoder_dim[0], + _encoder_dims0, + ) feature_mask_dropout_prob = 0.125 # mask1 shape: (1, batch_size, 1) - mask1 = (torch.rand(1, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype) + mask1 = ( + torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob + ).to(x.dtype) # mask2 has additional sequences masked, about twice the number. 
- mask2 = torch.logical_and(mask1, - (torch.rand(1, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype)) + mask2 = torch.logical_and( + mask1, + ( + torch.rand(1, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype), + ) # dim: (1, batch_size, 2) mask = torch.cat((mask1, mask2), dim=-1) @@ -254,8 +264,9 @@ class Zipformer2(EncoderInterface): feature_masks = [] for i in range(num_encoders): channels = self.encoder_dim[i] - feature_mask = torch.ones(1, batch_size, channels, - dtype=x.dtype, device=x.device) + feature_mask = torch.ones( + 1, batch_size, channels, dtype=x.dtype, device=x.device + ) u1 = self.encoder_unmasked_dim[i] u2 = u1 + (channels - u1) // 2 @@ -295,7 +306,8 @@ class Zipformer2(EncoderInterface): return chunk_size, left_context_chunks def forward( - self, x: Tensor, + self, + x: Tensor, x_lens: Tensor, src_key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: @@ -333,12 +345,17 @@ class Zipformer2(EncoderInterface): ds = self.downsampling_factor[i] x = convert_num_channels(x, self.encoder_dim[i]) - x = module(x, - chunk_size=chunk_size, - feature_mask=feature_masks[i], - src_key_padding_mask=(None if src_key_padding_mask is None - else src_key_padding_mask[...,::ds]), - attn_mask=attn_mask) + x = module( + x, + chunk_size=chunk_size, + feature_mask=feature_masks[i], + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + attn_mask=attn_mask, + ) outputs.append(x) # if the last output has the largest dimension, x will be unchanged, @@ -348,9 +365,11 @@ class Zipformer2(EncoderInterface): x = self._get_full_dim_output(outputs) x = self.downsample_output(x) # class Downsample has this rounding behavior.. - - saved = [outputs[i].permute(1,0,2) for i in self.output_layers] # collect the embeddings - + + saved = [ + outputs[i].permute(1, 0, 2) for i in self.output_layers + ] # collect the embeddings + assert self.output_downsampling_factor == 2, self.output_downsampling_factor if torch.jit.is_scripting() or torch.jit.is_tracing(): lengths = (x_lens + 1) // 2 @@ -362,9 +381,7 @@ class Zipformer2(EncoderInterface): return x, lengths, saved def _get_attn_mask( - self, x: Tensor, - chunk_size: int, - left_context_chunks: int + self, x: Tensor, chunk_size: int, left_context_chunks: int ) -> Optional[Tensor]: """ Return None if chunk_size == -1, else return attention mask of shape @@ -379,9 +396,11 @@ class Zipformer2(EncoderInterface): assert all(chunk_size % d == 0 for d in self.downsampling_factor) if left_context_chunks >= 0: num_encoders = len(self.encoder_dim) - assert all (chunk_size * left_context_chunks >= - (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] - for i in range(num_encoders)) + assert all( + chunk_size * left_context_chunks + >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders) + ) else: left_context_chunks = 1000000 @@ -399,8 +418,7 @@ class Zipformer2(EncoderInterface): src_c = c tgt_c = c.unsqueeze(-1) - attn_mask = torch.logical_or(src_c > tgt_c, - src_c < tgt_c - left_context_chunks) + attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) if __name__ == "__main__": logging.info(f"attn_mask = {attn_mask}") return attn_mask @@ -409,7 +427,7 @@ class Zipformer2(EncoderInterface): num_encoders = len(self.encoder_dim) assert len(outputs) == num_encoders output_dim = max(self.encoder_dim) - output_pieces = [ outputs[-1] ] + output_pieces = [outputs[-1]] cur_dim 
= self.encoder_dim[-1] for i in range(num_encoders - 2, -1, -1): d = self.encoder_dim[i] @@ -506,21 +524,38 @@ class Zipformer2(EncoderInterface): nonlin_attn_head_dim = 3 * embed_dim // 4 conv_left_pad = self.cnn_module_kernel[i] // 2 for layer in range(num_layers): - cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(device) - cached_nonlin_attn = torch.zeros(1, batch_size, downsample_left, nonlin_attn_head_dim).to(device) - cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(device) - cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(device) - cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(device) - cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(device) - states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2] + cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( + device + ) + cached_nonlin_attn = torch.zeros( + 1, batch_size, downsample_left, nonlin_attn_head_dim + ).to(device) + cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] return states def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: - return ScheduledFloat((0.0, x), - (20000.0, ratio * x), - default=x) + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) def _balancer_schedule(min_prob: float): @@ -542,31 +577,45 @@ class Zipformer2EncoderLayer(nn.Module): >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) """ + def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - value_head_dim: int, - feedforward_dim: int, - dropout: FloatLike = 0.1, - cnn_module_kernel: int = 31, - causal: bool = False, - attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), - conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), - const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), - ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), - ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), - bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0), + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + causal: bool = False, + attention_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + conv_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + const_attention_rate: FloatLike = ScheduledFloat( + (0.0, 0.25), (4000.0, 0.025), default=0 + ), + ff2_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + ff3_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + bypass_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.5), (4000.0, 0.02), default=0 + ), ) -> 
None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim # self.bypass implements layer skipping as well as bypass; see its default values. - self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate, - straight_through_rate=0) + self.bypass = BypassModule( + embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 + ) # bypass_mid is bypass used in the middle of the layer. self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) @@ -584,39 +633,39 @@ class Zipformer2EncoderLayer(nn.Module): self.const_attention_rate = copy.deepcopy(const_attention_rate) self.self_attn_weights = RelPositionMultiheadAttentionWeights( - embed_dim, pos_dim=pos_dim, num_heads=num_heads, - query_head_dim=query_head_dim, pos_head_dim=pos_head_dim, + embed_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, dropout=0.0, ) - self.self_attn1 = SelfAttention(embed_dim, num_heads, - value_head_dim) + self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) - self.self_attn2 = SelfAttention(embed_dim, num_heads, - value_head_dim) + self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim) - self.feed_forward1 = FeedforwardModule(embed_dim, - (feedforward_dim * 3) // 4, - dropout) + self.feed_forward1 = FeedforwardModule( + embed_dim, (feedforward_dim * 3) // 4, dropout + ) - self.feed_forward2 = FeedforwardModule(embed_dim, - feedforward_dim, - dropout) + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - self.feed_forward3 = FeedforwardModule(embed_dim, - (feedforward_dim * 5) // 4, - dropout) + self.feed_forward3 = FeedforwardModule( + embed_dim, (feedforward_dim * 5) // 4, dropout + ) - self.nonlin_attention = NonlinAttention(embed_dim, - hidden_channels=3 * embed_dim // 4) + self.nonlin_attention = NonlinAttention( + embed_dim, hidden_channels=3 * embed_dim // 4 + ) - self.conv_module1 = ConvolutionModule(embed_dim, - cnn_module_kernel, - causal=causal) + self.conv_module1 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) - self.conv_module2 = ConvolutionModule(embed_dim, - cnn_module_kernel, - causal=causal) + self.conv_module2 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) # TODO: remove it self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) @@ -624,15 +673,20 @@ class Zipformer2EncoderLayer(nn.Module): self.norm = BiasNorm(embed_dim) self.balancer1 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.45, max_positive=0.55, - min_abs=0.2, max_abs=4.0, + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.2, + max_abs=4.0, ) # balancer for output of NonlinAttentionModule self.balancer_na = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.3, max_positive=0.7, + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), prob=0.05, # out of concern for memory usage ) @@ -641,34 +695,50 @@ class Zipformer2EncoderLayer(nn.Module): # small. give this a very small probability, even at the start of # training, it's to fix a rare problem and it's OK to fix it slowly. 
self.balancer_ff2 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.3, max_positive=0.7, + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), max_abs=2.0, prob=0.05, ) self.balancer_ff3 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.3, max_positive=0.7, + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), max_abs=4.0, prob=0.05, ) - self.whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(4.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01) - - self.balancer2 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.45, max_positive=0.55, - min_abs=0.1, max_abs=4.0, + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(4.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, ) - def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]: - if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): + self.balancer2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.1, + max_abs=4.0, + ) + + def get_sequence_dropout_mask( + self, x: Tensor, dropout_rate: float + ) -> Optional[Tensor]: + if ( + dropout_rate == 0.0 + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): return None batch_size = x.shape[1] mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) @@ -694,21 +764,21 @@ class Zipformer2EncoderLayer(nn.Module): src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. 
- Returns: - A tensor which has the same shape as src + Returns: + A tensor which has the same shape as src """ src_orig = src @@ -716,7 +786,9 @@ class Zipformer2EncoderLayer(nn.Module): if torch.jit.is_scripting() or torch.jit.is_tracing(): attention_skip_rate = 0.0 else: - attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0 + attention_skip_rate = ( + float(self.attention_skip_rate) if self.training else 0.0 + ) # attn_weights: (num_heads, batch_size, seq_len, seq_len) attn_weights = self.self_attn_weights( @@ -728,7 +800,9 @@ class Zipformer2EncoderLayer(nn.Module): src = src + self.feed_forward1(src) - self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate) + self_attn_dropout_mask = self.get_sequence_dropout_mask( + src, attention_skip_rate + ) selected_attn_weights = attn_weights[0:1] if torch.jit.is_scripting() or torch.jit.is_tracing(): @@ -739,53 +813,75 @@ class Zipformer2EncoderLayer(nn.Module): # averaging-over-time operation. # only need the mask, can just use the 1st one and expand later selected_attn_weights = selected_attn_weights[0:1] - selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype) - selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)) + selected_attn_weights = (selected_attn_weights > 0.0).to( + selected_attn_weights.dtype + ) + selected_attn_weights = selected_attn_weights * ( + 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) + ) na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) - src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask) + src = src + ( + na if self_attn_dropout_mask is None else na * self_attn_dropout_mask + ) self_attn = self.self_attn1(src, attn_weights) - src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) if torch.jit.is_scripting() or torch.jit.is_tracing(): conv_skip_rate = 0.0 else: conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size, - src_key_padding_mask=src_key_padding_mask), - conv_skip_rate) + src = src + self.sequence_dropout( + self.conv_module1( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) if torch.jit.is_scripting() or torch.jit.is_tracing(): ff2_skip_rate = 0.0 else: ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)), - ff2_skip_rate) + src = src + self.sequence_dropout( + self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate + ) # bypass in the middle of the layer. 
src = self.bypass_mid(src_orig, src) self_attn = self.self_attn2(src, attn_weights) - src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) if torch.jit.is_scripting() or torch.jit.is_tracing(): conv_skip_rate = 0.0 else: conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size, - src_key_padding_mask=src_key_padding_mask), - conv_skip_rate) + src = src + self.sequence_dropout( + self.conv_module2( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) if torch.jit.is_scripting() or torch.jit.is_tracing(): ff3_skip_rate = 0.0 else: ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)), - ff3_skip_rate) + src = src + self.sequence_dropout( + self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate + ) src = self.balancer1(src) src = self.norm(src) @@ -929,20 +1025,22 @@ class Zipformer2Encoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) """ + def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - pos_dim: int, - dropout: float, - warmup_begin: float, - warmup_end: float, - initial_layerdrop_rate: float = 0.5, - final_layerdrop_rate: float = 0.05, + self, + encoder_layer: nn.Module, + num_layers: int, + pos_dim: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, ) -> None: super().__init__() - self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15, - length_factor=1.0) + self.encoder_pos = CompactRelPositionalEncoding( + pos_dim, dropout_rate=0.15, length_factor=1.0 + ) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -951,13 +1049,15 @@ class Zipformer2Encoder(nn.Module): assert 0 <= warmup_begin <= warmup_end - delta = (1. / num_layers) * (warmup_end - warmup_begin) + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) cur_begin = warmup_begin # interpreted as a training batch index for i in range(num_layers): cur_end = cur_begin + delta - self.layers[i].bypass.skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate), - (cur_end, final_layerdrop_rate), - default=0.0) + self.layers[i].bypass.skip_rate = ScheduledFloat( + (cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), + default=0.0, + ) cur_begin = cur_end def forward( @@ -1031,8 +1131,13 @@ class Zipformer2Encoder(nn.Module): new_states = [] for i, mod in enumerate(self.layers): ( - cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2 - ) = states[i * 6: (i + 1) * 6] + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) = states[i * 6 : (i + 1) * 6] ( output, new_cached_key, @@ -1040,7 +1145,7 @@ class Zipformer2Encoder(nn.Module): new_cached_val1, new_cached_val2, new_cached_conv1, - new_cached_conv2 + new_cached_conv2, ) = mod.streaming_forward( output, pos_emb, @@ -1072,13 +1177,15 @@ class BypassModule(nn.Module): "straight-through", i.e. to not do the bypass operation much initially, in order to force all the modules to learn something. 
""" + def __init__( - self, - embed_dim: int, - skip_rate: FloatLike = 0.0, - straight_through_rate: FloatLike = 0.0, - scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), - scale_max: FloatLike = 1.0): + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_max: FloatLike = 1.0, + ): super().__init__() self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) self.skip_rate = copy.deepcopy(skip_rate) @@ -1094,9 +1201,9 @@ class BypassModule(nn.Module): if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: return self.bypass_scale else: - ans = limit_param_value(self.bypass_scale, - min=float(self.scale_min), - max=float(self.scale_max)) + ans = limit_param_value( + self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) + ) skip_rate = float(self.skip_rate) if skip_rate != 0.0: mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate @@ -1105,13 +1212,14 @@ class BypassModule(nn.Module): # on which we have randomly chosen to do layer-skipping. straight_through_rate = float(self.straight_through_rate) if straight_through_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate + mask = ( + torch.rand((batch_size, 1), device=ans.device) + < straight_through_rate + ) ans = torch.maximum(ans, mask.to(ans.dtype)) return ans - def forward(self, - src_orig: Tensor, - src: Tensor): + def forward(self, src_orig: Tensor, src: Tensor): """ Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) Returns: something with the same shape as src and src_orig @@ -1126,15 +1234,13 @@ class DownsampledZipformer2Encoder(nn.Module): after convolutional downsampling, and then upsampled again at the output, and combined with the origin input, so that the output has the same shape as the input. """ - def __init__(self, - encoder: nn.Module, - dim: int, - downsample: int, - dropout: FloatLike): + + def __init__( + self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike + ): super(DownsampledZipformer2Encoder, self).__init__() self.downsample_factor = downsample - self.downsample = SimpleDownsample(dim, - downsample, dropout) + self.downsample = SimpleDownsample(dim, downsample, dropout) self.num_layers = encoder.num_layers self.encoder = encoder self.upsample = SimpleUpsample(dim, downsample) @@ -1166,7 +1272,7 @@ class DownsampledZipformer2Encoder(nn.Module): src = self.downsample(src) ds = self.downsample_factor if attn_mask is not None: - attn_mask = attn_mask[::ds,::ds] + attn_mask = attn_mask[::ds, ::ds] src = self.encoder( src, @@ -1177,7 +1283,7 @@ class DownsampledZipformer2Encoder(nn.Module): ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] + src = src[: src_orig.shape[0]] return self.out_combiner(src_orig, src) @@ -1213,7 +1319,7 @@ class DownsampledZipformer2Encoder(nn.Module): ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] + src = src[: src_orig.shape[0]] return self.out_combiner(src_orig, src), new_states @@ -1222,10 +1328,8 @@ class SimpleDownsample(torch.nn.Module): """ Does downsampling with attention, by weighted sum, and a projection.. 
""" - def __init__(self, - channels: int, - downsample: int, - dropout: FloatLike): + + def __init__(self, channels: int, downsample: int, dropout: FloatLike): super(SimpleDownsample, self).__init__() self.bias = nn.Parameter(torch.zeros(downsample)) @@ -1235,8 +1339,7 @@ class SimpleDownsample(torch.nn.Module): self.downsample = downsample - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, in_channels) Returns a tensor of shape @@ -1249,7 +1352,7 @@ class SimpleDownsample(torch.nn.Module): # Pad to an exact multiple of self.downsample # right-pad src, repeating the last element. pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) src = torch.cat((src, src_extra), dim=0) assert src.shape[0] == d_seq_len * ds @@ -1270,14 +1373,12 @@ class SimpleUpsample(torch.nn.Module): A very simple form of upsampling that mostly just repeats the input, but also adds a position-specific bias. """ - def __init__(self, - num_channels: int, - upsample: int): + + def __init__(self, num_channels: int, upsample: int): super(SimpleUpsample, self).__init__() self.upsample = upsample - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, num_channels) Returns a tensor of shape @@ -1315,11 +1416,13 @@ class CompactRelPositionalEncoding(torch.nn.Module): length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives less weight to small differences of offset near the origin. """ + def __init__( - self, embed_dim: int, - dropout_rate: FloatLike, - max_len: int = 1000, - length_factor: float = 1.0, + self, + embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, ) -> None: """Construct a CompactRelPositionalEncoding object.""" super(CompactRelPositionalEncoding, self).__init__() @@ -1343,19 +1446,22 @@ class CompactRelPositionalEncoding(torch.nn.Module): return # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] - x = torch.arange(-(T-1), T, - device=x.device).to(torch.float32).unsqueeze(1) + x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution # for small time offsets but less resolution for large time offsets. - compression_length = (self.embed_dim ** 0.5) + compression_length = self.embed_dim**0.5 # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; # but it does so more slowly than T for large absolute values of T. # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which # is important. - x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length)) + x_compressed = ( + compression_length + * x.sign() + * ((x.abs() + compression_length).log() - math.log(compression_length)) + ) # if self.length_factor == 1.0, then length_scale is chosen so that the # FFT can exactly separate points close to the origin (T == 0). 
So this @@ -1397,7 +1503,7 @@ class CompactRelPositionalEncoding(torch.nn.Module): - x_size_left + 1 : self.pe.size(0) // 2 # noqa E203 + x.size(0), - : + :, ] pos_emb = pos_emb.unsqueeze(0) return self.dropout(pos_emb) @@ -1424,15 +1530,14 @@ class RelPositionMultiheadAttentionWeights(nn.Module): """ def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - dropout: float = 0.0, - pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), - (4000.0, 0.0)) + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), ) -> None: super().__init__() self.embed_dim = embed_dim @@ -1451,13 +1556,16 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # dividing it between the query and key. Note: this module is intended # to be used with the ScaledAdam optimizer; with most other optimizers, # it would be necessary to apply the scaling factor in the forward function. - self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, - initial_scale=query_head_dim**-0.25) + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 + ) - self.whiten_keys = Whiten(num_groups=num_heads, - whitening_limit=_whitening_schedule(3.0), - prob=(0.025, 0.25), - grad_scale=0.025) + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) # add a balancer for the keys that runs with very small probability, and # tries to enforce that all dimensions have mean around zero. The @@ -1467,19 +1575,20 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # bias because the small numerical roundoff tends to have a non-random # sign. This module is intended to prevent that. Use a very small # probability; that should be suffixient to fix the problem. - self.balance_keys = Balancer(key_head_dim * num_heads, - channel_dim=-1, - min_positive=0.4, - max_positive=0.6, - min_abs=0.0, - max_abs=100.0, - prob=0.025) + self.balance_keys = Balancer( + key_head_dim * num_heads, + channel_dim=-1, + min_positive=0.4, + max_positive=0.6, + min_abs=0.0, + max_abs=100.0, + prob=0.025, + ) # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(pos_dim, - num_heads * pos_head_dim, - bias=False, - initial_scale=0.05) + self.linear_pos = ScaledLinear( + pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 + ) # the following are for diagnosics only, see --print-diagnostics option self.copy_pos_query = Identity() @@ -1515,10 +1624,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module): query_dim = query_head_dim * num_heads # self-attention - q = x[...,0:query_dim] - k = x[...,query_dim:2*query_dim] + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] # p is the position-encoding query - p = x[...,2*query_dim:] + p = x[..., 2 * query_dim :] assert p.shape[-1] == num_heads * pos_head_dim q = self.copy_query(q) # for diagnostics only, does nothing. 
@@ -1546,7 +1655,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module): if use_pos_scores: pos_emb = self.linear_pos(pos_emb) seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) @@ -1565,12 +1676,16 @@ class RelPositionMultiheadAttentionWeights(nn.Module): pos_scores = torch.gather(pos_scores, dim=1, index=indexes) pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) else: - pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len), - (pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2)-pos_scores.stride(3), - pos_scores.stride(3)), - storage_offset=pos_scores.stride(3) * (seq_len - 1)) + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, seq_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) attn_scores = attn_scores + pos_scores @@ -1589,10 +1704,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # but we view this as a failsafe to avoid "implausible" parameter # values rather than a regularization method that should be active # under normal circumstances. - attn_scores = penalize_abs_values_gt(attn_scores, - limit=25.0, - penalty=1.0e-04, - name=self.name) + attn_scores = penalize_abs_values_gt( + attn_scores, limit=25.0, penalty=1.0e-04, name=self.name + ) assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) @@ -1605,7 +1719,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module): attn_scores = attn_scores.masked_fill(attn_mask, -1000) if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape attn_scores = attn_scores.masked_fill( key_padding_mask.unsqueeze(1), -1000, @@ -1661,14 +1778,17 @@ class RelPositionMultiheadAttentionWeights(nn.Module): query_dim = query_head_dim * num_heads # self-attention - q = x[...,0:query_dim] - k = x[...,query_dim:2*query_dim] + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] # p is the position-encoding query - p = x[...,2*query_dim:] + p = x[..., 2 * query_dim :] assert p.shape[-1] == num_heads * pos_head_dim # Pad cached left contexts - assert cached_key.shape[0] == left_context_len, (cached_key.shape[0], left_context_len) + assert cached_key.shape[0] == left_context_len, ( + cached_key.shape[0], + left_context_len, + ) k = torch.cat([cached_key, k], dim=0) # Update cached left contexts cached_key = k[-left_context_len:, ...] @@ -1689,13 +1809,15 @@ class RelPositionMultiheadAttentionWeights(nn.Module): pos_emb = self.linear_pos(pos_emb) seq_len2 = 2 * seq_len - 1 + left_context_len - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) # [where seq_len2 represents relative position.] 
pos_scores = torch.matmul(p, pos_emb) - + if torch.jit.is_tracing(): (num_heads, batch_size, time1, n) = pos_scores.shape rows = torch.arange(start=time1 - 1, end=-1, step=-1) @@ -1709,16 +1831,25 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. else: - pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len), - (pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2)-pos_scores.stride(3), - pos_scores.stride(3)), - storage_offset=pos_scores.stride(3) * (seq_len - 1)) + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, k_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) attn_scores = attn_scores + pos_scores - assert attn_scores.shape == (num_heads, batch_size, seq_len, k_len), attn_scores.shape + assert attn_scores.shape == ( + num_heads, + batch_size, + seq_len, + k_len, + ), attn_scores.shape if key_padding_mask is not None: assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape @@ -1731,18 +1862,21 @@ class RelPositionMultiheadAttentionWeights(nn.Module): return attn_weights, cached_key - def _print_attn_entropy( - self, - attn_weights: Tensor): + def _print_attn_entropy(self, attn_weights: Tensor): # attn_weights: (num_heads, batch_size, seq_len, seq_len) (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) - attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( - dim=-1).mean(dim=(1,2)) - logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}") + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) class SelfAttention(nn.Module): @@ -1755,25 +1889,26 @@ class SelfAttention(nn.Module): num_heads: the number of attention heads value_head_dim: the value dimension per head """ + def __init__( - self, - embed_dim: int, - num_heads: int, - value_head_dim: int, + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, ) -> None: super().__init__() - self.in_proj = nn.Linear(embed_dim, - num_heads * value_head_dim, - bias=True) + self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) - self.out_proj = ScaledLinear(num_heads * value_head_dim, - embed_dim, bias=True, - initial_scale=0.05) + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + ) - self.whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(7.5, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01) + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) def forward( self, @@ -1802,8 +1937,11 @@ class SelfAttention(nn.Module): x = torch.matmul(attn_weights, x) # v: (num_heads, batch_size, seq_len, value_head_dim) - x = x.permute(2, 1, 0, 3).contiguous().view( - seq_len, batch_size, num_heads * value_head_dim) + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) # returned value is of shape (seq_len, 
batch_size, embed_dim), like the input. x = self.out_proj(x) @@ -1840,7 +1978,10 @@ class SelfAttention(nn.Module): x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) # Pad cached left contexts - assert cached_val.shape[0] == left_context_len, (cached_val.shape[0], left_context_len) + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) x = torch.cat([cached_val, x], dim=0) # Update cached left contexts cached_val = x[-left_context_len:, ...] @@ -1853,8 +1994,11 @@ class SelfAttention(nn.Module): x = torch.matmul(attn_weights, x) # v: (num_heads, batch_size, seq_len, value_head_dim) - x = x.permute(2, 1, 0, 3).contiguous().view( - seq_len, batch_size, num_heads * value_head_dim) + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) # returned value is of shape (seq_len, batch_size, embed_dim), like the input. x = self.out_proj(x) @@ -1863,33 +2007,38 @@ class SelfAttention(nn.Module): class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer2 model. - """ - def __init__(self, - embed_dim: int, - feedforward_dim: int, - dropout: FloatLike): + """Feedforward module in Zipformer2 model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(embed_dim, feedforward_dim) - self.hidden_balancer = Balancer(feedforward_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0) + self.hidden_balancer = Balancer( + feedforward_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) # shared_dim=0 means we share the dropout mask along the time axis - self.out_proj = ActivationDropoutAndLinear(feedforward_dim, embed_dim, - activation='SwooshL', - dropout_p=dropout, - dropout_shared_dim=0, bias=True, - initial_scale=0.1) + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation="SwooshL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.1, + ) - self.out_whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01) + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) def forward(self, x: Tensor): x = self.in_proj(x) @@ -1910,9 +2059,9 @@ class NonlinAttention(nn.Module): """ def __init__( - self, - channels: int, - hidden_channels: int, + self, + channels: int, + hidden_channels: int, ) -> None: super().__init__() @@ -1925,7 +2074,8 @@ class NonlinAttention(nn.Module): # starting from about 3, and poorly-trained instances of the module have smaller abs values # before the sigmoid. self.balancer = Balancer( - hidden_channels, channel_dim=-1, + hidden_channels, + channel_dim=-1, min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), min_abs=0.5, @@ -1937,19 +2087,23 @@ class NonlinAttention(nn.Module): self.identity2 = Identity() # for diagnostics. self.identity3 = Identity() # for diagnostics. 
- self.out_proj = ScaledLinear(hidden_channels, channels, - bias=True, - initial_scale=0.05) + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) - self.whiten1 = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(5.0), - prob=(0.025, 0.25), - grad_scale=0.01) + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) - self.whiten2 = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(5.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01) + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) def forward( self, @@ -1957,11 +2111,11 @@ class NonlinAttention(nn.Module): attn_weights: Tensor, ) -> Tensor: """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) -attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - Returns: - a Tensor with the same shape as x + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x """ x = self.in_proj(x) @@ -2031,13 +2185,21 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) (seq_len, batch_size, embed_dim) = x.shape num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, left_context_len + seq_len) + assert attn_weights.shape == ( + num_heads, + batch_size, + seq_len, + left_context_len + seq_len, + ) x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) # now x: (num_heads, batch_size, seq_len, head_dim) # Pad cached tensor - assert cached_x.shape[2] == left_context_len, (cached_x.shape[2], left_context_len) + assert cached_x.shape[2] == left_context_len, ( + cached_x.shape[2], + left_context_len, + ) x_pad = torch.cat([cached_x, x], dim=2) # Update cached tensor cached_x = x_pad[:, :, -left_context_len:, :] @@ -2062,8 +2224,12 @@ class ConvolutionModule(nn.Module): bias (bool): Whether to use bias in conv layers (default=True). """ + def __init__( - self, channels: int, kernel_size: int, causal: bool, + self, + channels: int, + kernel_size: int, + causal: bool, ) -> None: """Construct a ConvolutionModule object.""" super(ConvolutionModule, self).__init__() @@ -2074,7 +2240,8 @@ class ConvolutionModule(nn.Module): self.causal = causal self.in_proj = nn.Linear( - channels, 2 * bottleneck_dim, + channels, + 2 * bottleneck_dim, ) # the gradients on in_proj are a little noisy, likely to do with the # sigmoid in glu. @@ -2093,7 +2260,8 @@ class ConvolutionModule(nn.Module): # it will be in a better position to start learning something, i.e. to latch onto # the correct range. 
self.balancer1 = Balancer( - bottleneck_dim, channel_dim=-1, + bottleneck_dim, + channel_dim=-1, min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), max_positive=1.0, min_abs=1.5, @@ -2108,31 +2276,40 @@ class ConvolutionModule(nn.Module): assert kernel_size % 2 == 1 - self.depthwise_conv = ChunkCausalDepthwiseConv1d( - channels=bottleneck_dim, - kernel_size=kernel_size) if causal else nn.Conv1d( - in_channels=bottleneck_dim, - out_channels=bottleneck_dim, - groups=bottleneck_dim, - kernel_size=kernel_size, - padding=kernel_size // 2) + self.depthwise_conv = ( + ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + if causal + else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) self.balancer2 = Balancer( - bottleneck_dim, channel_dim=1, + bottleneck_dim, + channel_dim=1, min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), max_positive=1.0, min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), max_abs=10.0, ) - self.whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01) + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) self.out_proj = ActivationDropoutAndLinear( - bottleneck_dim, channels, activation='SwooshR', - dropout_p=0.0, initial_scale=0.05, + bottleneck_dim, + channels, + activation="SwooshR", + dropout_p=0.0, + initial_scale=0.05, ) def forward( @@ -2170,9 +2347,15 @@ class ConvolutionModule(nn.Module): if src_key_padding_mask is not None: x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0: + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and chunk_size >= 0 + ): # Not support exporting a model for simulated streaming decoding - assert self.causal, "Must initialize model with causal=True if you use chunk_size" + assert ( + self.causal + ), "Must initialize model with causal=True if you use chunk_size" x = self.depthwise_conv(x, chunk_size=chunk_size) else: x = self.depthwise_conv(x) @@ -2242,10 +2425,12 @@ def _test_zipformer_main(causal: bool = False): # Just make sure the forward pass runs. c = Zipformer2( - encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4), + encoder_dim=(64, 96), + encoder_unmasked_dim=(48, 64), + num_heads=(4, 4), causal=causal, chunk_size=(4,) if causal else (-1,), - left_context_frames=(64,) + left_context_frames=(64,), ) batch_size = 5 seq_len = 20