diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 1bf4593d4..ae7caa860 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -105,59 +105,81 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--num-encoder-layers", type=str, default="2,4,3,2,2,4", - help="Number of zipformer encoder layers, comma separated.", + help="Number of zipformer encoder layers per stack, comma separated.", ) - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,1536,1536,1536,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated" - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""" - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse." - ) - - parser.add_argument( - "--zipformer-downsampling-factors", + "--downsampling-factor", type=str, default="1,2,4,8,4,2", help="Downsampling factor for each stack of encoder layers.", ) + parser.add_argument( - "--cnn-module-kernels", + "--feedforward-dim", type=str, - default="31,31,31,31,31,31", - help="Sizes of kernels in convolution modules", + default="1024,1024,1536,1536,1536,1024", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="8", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="384", + help="Embedding dimension in encoder stacks: a single int or comma-separated list." + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="24", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list." + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list." + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list." + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="192", + help="Positional-encoding embedding dimension" + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim." 
+ ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", ) parser.add_argument( @@ -455,14 +477,19 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: encoder = Zipformer( num_features=params.feature_dim, output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + downsampling_factor=to_int_tuple(params.downsampling_factor), num_encoder_layers=to_int_tuple(params.num_encoder_layers), + encoder_dim=to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=to_int_tuple(params.query_head_dim), + pos_head_dim=to_int_tuple(params.pos_head_dim), + value_head_dim=to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=to_int_tuple(params.num_heads), + feedforward_dim=to_int_tuple(params.feedforward_dim), + cnn_module_kernel=to_int_tuple(params.cnn_module_kernel), + dropout=0.1, + warmup_batches=4000.0, ) return encoder @@ -479,7 +506,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(',')[-1]), + encoder_dim=int(params.encoder_dim.split(',')[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -496,7 +523,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(params.encoder_dims.split(',')[-1]), + encoder_dim=int(params.encoder_dim.split(',')[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 93d8a43bc..138c7409a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -47,73 +47,120 @@ from icefall.dist import get_rank class Zipformer(EncoderInterface): """ Args: - num_features (int): Number of input features - d_model: (int,int): embedding dimension of 2 encoder stacks - attention_dim: (int,int): attention dimension of 2 encoder stacks - nhead (int, int): number of heads - dim_feedforward (int, int): feedforward dimention in 2 encoder stacks - num_encoder_layers (int): number of encoder layers + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + + num_features (int): Number of input features, e.g. 40. + output_downsampling_factor (int): how much to downsample at the output. Note: + we also downsample by a factor of 2 in the Conv2dSubsampling encoder. + You should probably leave this at 2. + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. 
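For orientation: the options above arrive as plain comma-separated strings, and get_encoder_model converts them with to_int_tuple, which is defined elsewhere in train.py and not shown in this diff. A minimal sketch of what such a helper presumably does (hypothetical implementation, illustration only):

```python
# Hypothetical sketch only; the real to_int_tuple lives elsewhere in train.py.
def to_int_tuple(s: str) -> tuple:
    return tuple(int(x) for x in s.split(","))

assert to_int_tuple("2,4,3,2,2,4") == (2, 4, 3, 2, 2, 4)  # --num-encoder-layers default
assert to_int_tuple("384") == (384,)                      # scalar options stay length-1 here
```

Scalar options such as --encoder-dim=384 therefore reach Zipformer as one-element tuples and are broadcast to one value per encoder stack inside the model.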
+ Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + num_encoder_layers (int or Tuple[int]): number of encoder layers for each stack + encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of + the encoder stacks for purposes of per-frame dropout (recommend 256 for + now). + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head; per stack, if a tuple. + value_head_dim (int or Tuple[int]): dimension of value in each attention head + pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per + attention head + num_heads (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int]): Kernel size of convolution module + + pos_dim (int): the dimension of each positional-encoding vector prior to projection, + e.g. 128. + dropout (float): dropout rate - cnn_module_kernel (int): Kernel size of convolution module - vgg_frontend (bool): whether to use vgg frontend. - warmup_batches (float): number of batches to warm up over + warmup_batches (float): number of batches to warm up over; this controls + dropout of encoder layers. """ def __init__( - self, - num_features: int, - output_downsampling_factor: int = 2, - encoder_dims: Tuple[int] = (384, 384), - attention_dim: Tuple[int] = (256, 256), - encoder_unmasked_dims: Tuple[int] = (256, 256), - zipformer_downsampling_factors: Tuple[int] = (2, 4), - nhead: Tuple[int] = (8, 8), - feedforward_dim: Tuple[int] = (1536, 2048), - num_encoder_layers: Tuple[int] = (12, 12), - dropout: float = 0.1, - cnn_module_kernels: Tuple[int] = (31, 31), - pos_dim: int = 4, - warmup_batches: float = 4000.0, + self, + num_features: int, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + encoder_unmasked_dim: Union[int, Tuple[int]] = 256, + query_head_dim: Union[int, Tuple[int]] = 24, + pos_head_dim: Union[int, Tuple[int]] = 4, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_dim: Union[int, Tuple[int]] = 1536, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + pos_dim: int = 192, + dropout: float = 0.1, + warmup_batches: float = 4000.0, ) -> None: super(Zipformer, self).__init__() - self.num_features = num_features - self.encoder_unmasked_dims = encoder_unmasked_dims - assert 0 < encoder_dims[0] <= encoder_dims[1] - self.encoder_dims = encoder_dims - self.encoder_unmasked_dims = encoder_unmasked_dims - self.zipformer_downsampling_factors = zipformer_downsampling_factors - self.output_downsampling_factor = output_downsampling_factor + def _to_tuple(x): + """ Converts a single int or a 1-tuple of an int to a tuple with the same length as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x - # will be written to, see set_batch_count() + self.num_features = num_features # int + self.output_downsampling_factor = output_downsampling_factor # int + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple
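The _to_tuple closure above is what makes the Union[int, Tuple[int]] arguments work: the number of stacks is fixed by len(downsampling_factor), and any scalar setting is repeated once per stack. A standalone restatement of that rule, using the train.py default factors (illustration only):

```python
from typing import Tuple, Union

def broadcast_per_stack(x: Union[int, Tuple[int, ...]],
                        downsampling_factor: Tuple[int, ...]) -> Tuple[int, ...]:
    # Mirrors Zipformer._to_tuple: a single int (or 1-tuple) is repeated once per
    # encoder stack; anything else must already have one entry per stack.
    if isinstance(x, int):
        x = (x,)
    if len(x) == 1:
        x = x * len(downsampling_factor)
    else:
        assert len(x) == len(downsampling_factor)
    return x

factors = (1, 2, 4, 8, 4, 2)  # six stacks, the --downsampling-factor default
assert broadcast_per_stack(384, factors) == (384,) * 6
assert broadcast_per_stack((1024, 1024, 1536, 1536, 1536, 1024), factors) == \
    (1024, 1024, 1536, 1536, 1536, 1024)
```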
+ self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + query_head_dim = _to_tuple(query_head_dim) + value_head_dim = _to_tuple(value_head_dim) + pos_head_dim = _to_tuple(pos_head_dim) + num_heads = _to_tuple(num_heads) + feedforward_dim = _to_tuple(feedforward_dim) + cnn_module_kernel = _to_tuple(cnn_module_kernel) + + # will be written to in training loop, see set_batch_count() self.batch_count = 0 self.warmup_end = warmup_batches - for u,d in zip(encoder_unmasked_dims, encoder_dims): + for u,d in zip(encoder_unmasked_dim, encoder_dim): assert u <= d # self.encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, T//subsampling_factor, encoder_dims). + # to the shape (N, (T - 7) // 2, encoder_dims). # That is, it does two things simultaneously: - # (1) subsampling: T -> T//2 + # (1) subsampling: T -> (T - 7) // 2 # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0], + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + self.encoder_embed = Conv2dSubsampling(num_features, encoder_dim[0], dropout=dropout) # each one will be ZipformerEncoder or DownsampledZipformerEncoder encoders = [] - num_encoders = len(encoder_dims) + num_encoders = len(downsampling_factor) for i in range(num_encoders): encoder_layer = ZipformerEncoderLayer( - encoder_dims[i], - attention_dim[i], - nhead[i], + encoder_dim[i], + pos_dim, + num_heads[i], + query_head_dim[i], + pos_head_dim[i], + value_head_dim[i], feedforward_dim[i], dropout, - cnn_module_kernels[i], - pos_dim, + cnn_module_kernel[i], ) # For the segment of the warmup period, we let the Conv2dSubsampling @@ -121,17 +168,18 @@ class Zipformer(EncoderInterface): encoder = ZipformerEncoder( encoder_layer, num_encoder_layers[i], + pos_dim, dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), warmup_end=warmup_batches * (i + 2) / (num_encoders + 1) ) - if zipformer_downsampling_factors[i] != 1: + if downsampling_factor[i] != 1: encoder = DownsampledZipformerEncoder( encoder, - input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0], - output_dim=encoder_dims[i], - downsample=zipformer_downsampling_factors[i], + input_dim=encoder_dim[i-1] if i > 0 else encoder_dim[0], + output_dim=encoder_dim[i], + downsample=downsampling_factor[i], ) encoders.append(encoder) self.encoders = nn.ModuleList(encoders) @@ -139,8 +187,8 @@ class Zipformer(EncoderInterface): # initializes self.skip_layers and self.skip_modules self._init_skip_modules() - self.downsample_output = AttentionDownsample(encoder_dims[-1], - encoder_dims[-1], + self.downsample_output = AttentionDownsample(encoder_dim[-1], + encoder_dim[-1], downsample=output_downsampling_factor) @@ -157,14 +205,14 @@ class Zipformer(EncoderInterface): def _init_skip_modules(self): """ - If self.zipformer_downampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer + If self.downsampling_factor = (1, 2, 4, 8, 4, 2), then at the input of layer indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2, we combine the outputs of layers 1 and 4.
""" skip_layers = [] skip_modules = [] - z = self.zipformer_downsampling_factors + z = self.downsampling_factor for i in range(len(z)): if i <= 1 or z[i-1] <= z[i]: skip_layers.append(None) @@ -175,11 +223,11 @@ class Zipformer(EncoderInterface): if z[j] <= z[i] or j == 0: # TEMP logging statement. logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " - f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.") + f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.") skip_layers.append(j) - skip_modules.append(SimpleCombiner(self.encoder_dims[j], - self.encoder_dims[i-1], - min_weight=(0.0,0.25))) + skip_modules.append(SimpleCombiner(self.encoder_dim[j], + self.encoder_dim[i-1], + min_weight=(0.0, 0.25))) break self.skip_layers = skip_layers self.skip_modules = nn.ModuleList(skip_modules) @@ -202,20 +250,18 @@ class Zipformer(EncoderInterface): x: the embeddings (needed for the shape and dtype and device), of shape (num_frames, batch_size, encoder_dims0) """ - num_encoders = len(self.encoder_dims) + num_encoders = len(self.encoder_dim) if not self.training: return [ 1.0 ] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape + assert self.encoder_dim[0] == _encoder_dims0 - assert self.encoder_dims[0] == _encoder_dims0 - - max_downsampling_factor = max(self.zipformer_downsampling_factors) + max_downsampling_factor = max(self.downsampling_factor) num_frames_max = (num_frames0 + max_downsampling_factor - 1) - feature_mask_dropout_prob = 0.15 # frame_mask_max shape: (num_frames_max, batch_size, 1) @@ -225,7 +271,7 @@ class Zipformer(EncoderInterface): feature_masks = [] for i in range(num_encoders): - ds = self.zipformer_downsampling_factors[i] + ds = self.downsampling_factor[i] upsample_factor = (max_downsampling_factor // ds) frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor, @@ -233,9 +279,9 @@ class Zipformer(EncoderInterface): .reshape(num_frames_max * upsample_factor, batch_size, 1)) num_frames = (num_frames0 + ds - 1) // ds frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i], + feature_mask = torch.ones(num_frames, batch_size, self.encoder_dim[i], dtype=x.dtype, device=x.device) - u = self.encoder_unmasked_dims[i] + u = self.encoder_unmasked_dim[i] feature_mask[:, :, u:] *= frame_mask feature_masks.append(feature_mask) @@ -254,7 +300,7 @@ class Zipformer(EncoderInterface): `x` before padding. Returns: Return a tuple containing 2 tensors: - - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) + - embeddings: its shape is (batch_size, output_seq_len, encoder_dim[-1]) - lengths, a tensor of shape (batch_size,) containing the number of frames in `embeddings` before padding. """ @@ -272,7 +318,7 @@ class Zipformer(EncoderInterface): feature_masks = self.get_feature_masks(x) for i, module in enumerate(self.encoders): - ds = self.zipformer_downsampling_factors[i] + ds = self.downsampling_factor[i] if self.skip_layers[i] is not None: layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() if (not self.training) or random.random() > layer_skip_dropout_prob: @@ -294,74 +340,79 @@ class Zipformer(EncoderInterface): class ZipformerEncoderLayer(nn.Module): """ - ZipformerEncoderLayer is made up of self-attn, feedforward and convolution networks. 
- See: "Zipformer: Convolution-augmented Transformer for Speech Recognition" - Args: - d_model: the number of expected features in the input (required). + embed_dim: the number of expected features in the input (required). nhead: the number of heads in the multiheadattention models (required). feedforward_dim: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). cnn_module_kernel (int): Kernel size of convolution module. Examples:: - >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> encoder_layer = ZipformerEncoderLayer(embed_dim=512, nhead=8) >>> src = torch.rand(10, 32, 512) >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) """ def __init__( self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, dropout: float = 0.1, cnn_module_kernel: int = 31, - pos_dim: int = 4, ) -> None: super(ZipformerEncoderLayer, self).__init__() + self.embed_dim = embed_dim - self.d_model = d_model - - # will be written to, see set_batch_count() + # will be written to in training loop, see set_batch_count() self.batch_count = 0 - self.self_attn = RelPositionMultiheadAttention( - d_model, attention_dim, nhead, pos_dim, dropout=0.0, + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, pos_dim, num_heads, + query_head_dim, pos_head_dim, dropout=0.0, ) - self.feed_forward1 = FeedforwardModule(d_model, + self.self_attn1 = SelfAttention(embed_dim, num_heads, + value_head_dim) + + self.self_attn2 = SelfAttention(embed_dim, num_heads, + value_head_dim) + + self.feed_forward1 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - self.feed_forward2 = FeedforwardModule(d_model, + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - self.feed_forward3 = FeedforwardModule(d_model, + self.feed_forward3 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - #self.conv_module1 = ConvolutionModule(d_model, + #self.conv_module1 = ConvolutionModule(embed_dim, #cnn_module_kernel) - self.nonlin_attention_module = NonlinAttentionModule(d_model) + self.nonlin_attention_module = NonlinAttentionModule(embed_dim) - self.conv_module = ConvolutionModule(d_model, - cnn_module_kernel) + self.conv_module = ConvolutionModule(embed_dim, + cnn_module_kernel) - self.squeeze_excite1 = ModifiedSEModule(d_model) - self.squeeze_excite2 = ModifiedSEModule(d_model) + self.attention_squeeze1 = AttentionSqueeze(embed_dim) + self.attention_squeeze2 = AttentionSqueeze(embed_dim) - self.norm_final = BasicNorm(d_model) + self.norm_final = BasicNorm(embed_dim) self.bypass_scale = nn.Parameter(torch.tensor(0.5)) # try to ensure the output is close to zero-mean (or at least, zero-median). self.balancer = ActivationBalancer( - d_model, channel_dim=-1, + embed_dim, channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0, ) @@ -435,41 +486,45 @@ class ZipformerEncoderLayer(nn.Module): dynamic_dropout = self.get_dynamic_dropout_rate() # multi-headed self-attention module + # TODO: make the various attention-using models be dropped + # out independently. 
use_self_attn = (random.random() > dynamic_dropout) if torch.jit.is_scripting() or use_self_attn: - src_att, attn_weights = self.self_attn( + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( src, pos_emb=pos_emb, attn_mask=src_mask, key_padding_mask=src_key_padding_mask, ) - src = src + src_att + + if torch.jit.is_scripting() or use_self_attn: + src = src + self.self_attn1( + src, attn_weights) # convolution module - if torch.jit.is_scripting() or random.random() > dynamic_dropout: - src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + if torch.jit.is_scripting() or use_self_attn: + src = src + self.nonlin_attention_module(src, + attn_weights[0:1]) src = src + self.feed_forward2(src) # pooling module if torch.jit.is_scripting() or use_self_attn: - src = src + self.squeeze_excite1(src, attn_weights, head_idx=0) + src = src + self.attention_squeeze1(src, attn_weights[1:2]) if torch.jit.is_scripting() or use_self_attn: - self_attn_output2 = self.self_attn.forward2(src, attn_weights) - src = src + self_attn_output2 + src = src + self.self_attn2( + src, attn_weights) - # attention version of convolution module - if torch.jit.is_scripting() or use_self_attn: - src = src + self.nonlin_attention_module(src, - attn_weights, - head_idx=1) + if torch.jit.is_scripting() or random.random() > dynamic_dropout: + src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.feed_forward3(src) # pooling module if torch.jit.is_scripting() or use_self_attn: - src = src + self.squeeze_excite2(src, attn_weights, head_idx=2) + src = src + self.attention_squeeze2(src, attn_weights[2:3]) src = self.norm_final(self.balancer(src)) @@ -487,9 +542,10 @@ class ZipformerEncoder(nn.Module): Args: encoder_layer: an instance of the ZipformerEncoderLayer() class (required). num_layers: the number of sub-encoder-layers in the encoder (required). + pos_dim: the dimension for the relative positional encoding Examples:: - >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> encoder_layer = ZipformerEncoderLayer(embed_dim=512, nhead=8) >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6) >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) @@ -498,6 +554,7 @@ class ZipformerEncoder(nn.Module): self, encoder_layer: nn.Module, num_layers: int, + pos_dim: int, dropout: float, warmup_begin: float, warmup_end: float @@ -514,8 +571,7 @@ class ZipformerEncoder(nn.Module): # so that we can keep this consistent across worker tasks (for efficiency). self.module_seed = torch.randint(0, 1000, ()).item() - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, - dropout) + self.encoder_pos = RelPositionalEncoding(pos_dim, dropout) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -916,18 +972,18 @@ class RelPositionalEncoding(torch.nn.Module): Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py Args: - d_model: Embedding dimension. + embed_dim: Embedding dimension. dropout_rate: Dropout rate. max_len: Maximum input length. 
""" def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 + self, embed_dim: int, dropout_rate: float, max_len: int = 5000 ) -> None: """Construct a PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() - self.d_model = d_model + self.embed_dim = embed_dim self.dropout = torch.nn.Dropout(dropout_rate) self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len)) @@ -947,12 +1003,12 @@ class RelPositionalEncoding(torch.nn.Module): # Suppose `i` means to the position of query vecotr and `j` means the # position of key vector. We use position relative positions when keys # are to the left (i>j) and negative relative positions otherwise (i>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + embed_dim: number of channels at the input to this module, e.g. 256 + pos_dim: dimension of the positional encoding vectors, e.g. 128. + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. """ def __init__( - self, - embed_dim: int, - attention_dim: int, - num_heads: int, - pos_dim: int, - dropout: float = 0.0, + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, ) -> None: - super(RelPositionMultiheadAttention, self).__init__() + super().__init__() self.embed_dim = embed_dim - self.attention_dim = attention_dim self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_head_dim = pos_head_dim self.dropout = dropout - self.head_dim = attention_dim // num_heads - self.pos_dim = pos_dim - assert self.head_dim % 2 == 0, self.head_dim - assert ( - self.head_dim * num_heads == attention_dim - ) + + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = (2 * attention_dim + # query, key - attention_dim // 2 + # value - pos_dim * num_heads) # positional encoding query - + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, - initial_scale=self.head_dim**-0.25) + initial_scale=query_head_dim**-0.25) - # self.whiten_values is applied on the values in forward(); - # it just copies the keys but prevents low-rank distribution by modifying grads. - self.whiten_values = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) + # .. TODO: tune this limit? whitening_limit. self.whiten_keys = Whiten(num_groups=num_heads, whitening_limit=2.0, prob=(0.025, 0.25), @@ -1052,7 +1100,9 @@ class RelPositionMultiheadAttention(nn.Module): # linear transformation for positional encoding. 
- self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False, + self.linear_pos = ScaledLinear(pos_dim, + num_heads * pos_head_dim, + bias=False, initial_scale=0.05) # this is to stop a failure mode where the output gets very small and is @@ -1071,26 +1121,6 @@ class RelPositionMultiheadAttention(nn.Module): self.copy_pos_query = Identity() self.copy_query = Identity() - self.out_proj = ScaledLinear( - attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 - ) - - self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True, - initial_scale=0.05) - # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - self.out_balancer2 = ActivationBalancer(embed_dim, - channel_dim=-1, - min_positive=0.33, - max_positive=0.66, - min_abs=0.005, max_abs=1.0, - min_prob=0.05) - - def forward( self, @@ -1098,378 +1128,235 @@ class RelPositionMultiheadAttention(nn.Module): pos_emb: Tensor, key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor, Tensor]: + ) -> Tensor: r""" Args: - x: input to be projected to query, key, value - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - - Returns: (attn_output, attn_weights) - - - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads - and S is the sequence length. 
+ x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 2, pos_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). """ - x, weights = self.multi_head_attention_forward( - self.in_proj(x), - self.linear_pos(pos_emb), - self.attention_dim, - self.num_heads, - self.dropout, - training=self.training, - key_padding_mask=key_padding_mask, - attn_mask=attn_mask, - ) - return x, weights + x = self.in_proj(x) + pos_emb = self.linear_pos(pos_emb) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + dropout = self.dropout, + training = self.training, + seq_len, batch_size, _ = x.shape - def multi_head_attention_forward( - self, - x_proj: Tensor, - pos: Tensor, - attention_dim: int, - num_heads: int, - dropout_p: float, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - x_proj: the projected input, to be split into query, key, value. - pos: head-specific biases arising from the positional embeddings. - attention_dim: dimension inside attention mechanism - num_heads: parallel attention heads. - dropout_p: probability of an element to be zeroed. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - Inputs: - - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is - the attention dimension. Will be split into (query, key, value, pos). - - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence - length, N is the batch size, and A is the attention dim. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. 
If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * H, S, S)` where N is the batch size, - H is the num-heads, S is the sequence length. - """ - - seq_len, bsz, _ = x_proj.size() - - head_dim = attention_dim // num_heads - pos_dim = self.pos_dim # positional-encoding dim per head - assert ( - head_dim * num_heads == attention_dim - ), "attention_dim must be divisible by num_heads" - + query_dim = query_head_dim * num_heads # self-attention - q = x_proj[...,0:attention_dim] - k = x_proj[...,attention_dim:2*attention_dim] - value_dim = attention_dim // 2 - v = x_proj[...,2*attention_dim:2*attention_dim+value_dim] - # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[...,2*attention_dim+value_dim:] + q = x[...,0:query_dim] + k = x[...,query_dim:2*query_dim] + # p is the position-encoding query + p = x[...,2*query_dim:] + assert p.shape[-1] == num_heads * pos_head_dim - k = self.whiten_keys(k) # does nothing in the forward pass. - v = self.whiten_values(v) # does nothing in the forward pass. q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(k) # does nothing in the forward pass. p = self.copy_pos_query(p) # for diagnostics only, does nothing. - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." - ) - attn_mask = attn_mask.to(torch.bool) + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - seq_len, - seq_len, - ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
- ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = q.reshape(seq_len, bsz, num_heads, head_dim) - p = p.reshape(seq_len, bsz, num_heads, pos_dim) - k = k.reshape(seq_len, bsz, num_heads, head_dim) - v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == seq_len, "{} == {}".format( - key_padding_mask.size(1), seq_len - ) - - - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) - p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) seq_len2 = 2 * seq_len - 1 - pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) - # pos shape now: (batch, head, pos_dim, seq_len2) + pos_emb = pos_emb.reshape(1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) + # pos shape now: (head, 1, pos_dim, seq_len2) - # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) # [where seq_len2 represents relative position.] - pos_weights = torch.matmul(p, pos) - # the following .as_strided() expression converts the last axis of pos_weights from relative + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of pos_scores from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len), - (pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2)-pos_weights.stride(3), - pos_weights.stride(3)), - storage_offset=pos_weights.stride(3) * (seq_len - 1)) + pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len), + (pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2)-pos_scores.stride(3), + pos_scores.stride(3)), + storage_offset=pos_scores.stride(3) * (seq_len - 1)) - - # caution: they are really scores at this point. - attn_output_weights = torch.matmul(q, k) + pos_weights + attn_scores = torch.matmul(q, k) + pos_scores if training and random.random() < 0.1: - # This is a harder way of limiting the attention scores to not be too large. - # It incurs a penalty if any of them has an absolute value greater than 50.0. - # this should be outside the normal range of the attention scores. We use - # this mechanism instead of, say, a limit on entropy, because once the entropy - # gets very small gradients through the softmax can become very small, and - # some mechanisms like that become ineffective. - attn_output_weights = penalize_abs_values_gt(attn_output_weights, - limit=25.0, - penalty=1.0e-04) + # This is a harder way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 50.0. this should be outside the normal range + # of the attention scores. 
We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small, gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choice of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt(attn_scores, + limit=25.0, + penalty=1.0e-04) - - # attn_output_weights: (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - seq_len, - seq_len, - ] + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float("-inf")) - else: - attn_output_weights += attn_mask + assert attn_mask.dtype == torch.bool + attn_scores.masked_fill_(attn_mask, float("-inf")) if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, seq_len, seq_len + assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape + + attn_scores = attn_scores.view( num_heads, batch_size, seq_len, seq_len ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), float("-inf"), ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - # Using this version of softmax, defined in scaling.py, - # should save a little of the memory used in backprop by, if - # we are in automatic mixed precision mode (amp) == autocast, - # only storing the half-precision output for backprop purposes. - attn_output_weights = softmax(attn_output_weights, dim=-1) + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes.
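The .as_strided() expression above re-indexes the last axis of pos_scores from relative offsets (length 2*seq_len - 1) to absolute source positions. A self-contained check of that re-indexing against a plain-indexing reference; this only illustrates the stride arithmetic (the sign convention is, as the comment above says, defined by the code itself):

```python
import torch

def rel_to_abs(pos_scores: torch.Tensor, seq_len: int) -> torch.Tensor:
    # pos_scores: (..., seq_len, 2*seq_len - 1) -> (..., seq_len, seq_len),
    # where out[..., i, j] == pos_scores[..., i, (seq_len - 1) + j - i].
    return pos_scores.as_strided(
        pos_scores.shape[:-1] + (seq_len,),
        pos_scores.stride()[:-2]
        + (pos_scores.stride(-2) - pos_scores.stride(-1), pos_scores.stride(-1)),
        storage_offset=pos_scores.storage_offset()
        + pos_scores.stride(-1) * (seq_len - 1),
    )

num_heads, batch_size, seq_len = 2, 3, 5
pos_scores = torch.randn(num_heads, batch_size, seq_len, 2 * seq_len - 1)
i = torch.arange(seq_len).unsqueeze(-1)              # target position
j = torch.arange(seq_len).unsqueeze(0)               # source position
reference = pos_scores[..., i, (seq_len - 1) + j - i]
assert torch.allclose(rel_to_abs(pos_scores, seq_len), reference)
```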
+ attn_weights = softmax(attn_scores, dim=-1) - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training + if random.random() < 0.001: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training ) - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, - head_dim // 2] - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, attention_dim // 2) - ) - attn_output = self.out_proj(attn_output) - attn_output = self.out_balancer(attn_output) - - return attn_output, attn_output_weights + return attn_weights - def forward2( + def _print_attn_entropy( + self, + attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( + dim=-1).mean(dim=(1,2)) + logging.info(f"attn_weights_entropy = {attn_weights_entropy}") + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_dim: the value dimension per head + """ + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear(embed_dim, + num_heads * value_head_dim, + bias=False) + + # attempt to make the output of `in_proj` uncorrelated within each head + # and all heads having roughly the same magnitude. the hope is to + # improve learning dynamics; this loses no power as there is no constraint + # on the condition number of out_proj. + self.whiten_values = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + + self.out_proj = ScaledLinear(num_heads * value_head_dim, + embed_dim, bias=True, + initial_scale=0.05) + + # intended to prevent an observed failure mode where the output of this module is + # dominated by its mean. + self.out_balancer = ActivationBalancer(embed_dim, + channel_dim=-1, + min_positive=0.33, + max_positive=0.66, + min_abs=0.005, max_abs=1.0, + min_prob=0.05) + + def forward( self, x: Tensor, attn_weights: Tensor, ) -> Tensor: """ - Second forward function, where we re-use the attn_weights returned by the first forward function - but with different input. Args: - x: input, of shape (seq_len, batch_size, embed_dim) - attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. Returns: - output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) + a tensor with the same shape as x. """ - num_heads = self.num_heads - (seq_len, bsz, embed_dim) = x.shape - head_dim = self.attention_dim // num_heads - # v: (tgt_len, bsz, embed_dim // 2) - v = self.in_proj2(x) - v = self.whiten_values2(v) # does nothing in the forward pass. 
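SelfAttention above implements only the value path: it consumes attention weights of shape (num_heads, batch_size, seq_len, seq_len) that RelPositionMultiheadAttentionWeights produced once for the whole layer. A shape walk-through with the train.py defaults (num_heads=8, value_head_dim=12, encoder_dim=384); an illustration only, since the real module also applies whitening, a ScaledLinear out_proj back to embed_dim, and a balancer:

```python
import torch

num_heads, value_head_dim, embed_dim = 8, 12, 384
seq_len, batch_size = 10, 2

attn_weights = torch.softmax(torch.randn(num_heads, batch_size, seq_len, seq_len), dim=-1)
x = torch.randn(seq_len, batch_size, embed_dim)

values = torch.nn.Linear(embed_dim, num_heads * value_head_dim, bias=False)(x)
values = values.reshape(seq_len, batch_size, num_heads, value_head_dim).permute(2, 1, 0, 3)
# values: (num_heads, batch_size, src_seq_len, value_head_dim)

out = torch.matmul(attn_weights, values)     # (num_heads, batch_size, tgt_seq_len, value_head_dim)
out = out.permute(2, 1, 0, 3).reshape(seq_len, batch_size, num_heads * value_head_dim)
assert out.shape == (seq_len, batch_size, 96)   # 8 heads * 12 dims; out_proj maps this back to 384
```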
- v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - # now v: (bsz * num_heads, seq_len, head_dim // 2) - attn_output = torch.bmm(attn_weights, v) + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = self.whiten_values(x) # does nothing in the forward pass. + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] - if random.random() < 0.001 or __name__ == "__main__": - self._print_attn_stats(attn_weights, attn_output) + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) - # attn_output: (bsz * num_heads, seq_len, head_dim) - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, self.attention_dim // 2) - ) - # returned value is of shape (seq_len, bsz, embed_dim), like x. - attn_output = self.out_proj2(attn_output) - attn_output = self.out_balancer2(attn_output) - return attn_output + x = x.permute(2, 1, 0, 3).contiguous().view( + seq_len, batch_size, num_heads * value_head_dim) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.out_balancer(x) + + return x - def _print_attn_stats( - self, - attn_weights: Tensor, - attn_output: Tensor): - # attn_weights: (batch_size * num_heads, seq_len, seq_len) - # attn_output: (bsz * num_heads, seq_len, head_dim) - (n, seq_len, head_dim) = attn_output.shape - num_heads = self.num_heads - bsz = n // num_heads - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_output = attn_output.to(torch.float32) - attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( - dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2)) - attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim) - attn_output_mean = attn_output.mean(dim=1, keepdim=True) - attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len) - # attn_covar: (num_heads, head_dim, head_dim) - #eigs, _ = torch.symeig(attn_covar) - #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") - - attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) - embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2)) - out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2)) - logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}") - - - - -class ModifiedSEModule(nn.Module): +class AttentionSqueeze(nn.Module): """ A modified version of Squeeze-and-Excite, where the nonliearity happens in the full dim and we just project to a small bottleneck dimension. 
""" def __init__(self, - d_model: int, + embed_dim: int, bottleneck_dim: int = 16): super().__init__() self.bottleneck_dim = bottleneck_dim - self.in_proj = nn.Linear(d_model, d_model, + self.in_proj = nn.Linear(embed_dim, embed_dim, bias=False) - self.to_bottleneck_proj = ScaledLinear(d_model, - bottleneck_dim, - bias=False) + self.to_bottleneck_proj = nn.Linear(embed_dim, + bottleneck_dim, + bias=False) # Caution: this cannot work correctly with an extremeley small batch size, e.g. if # we were training with a single very long audio sequence, or just 2 or 3 sequences @@ -1477,7 +1364,7 @@ class ModifiedSEModule(nn.Module): # (although when the grads get back past the averaging operation they would # be quite small and would probably not hurt the rest of the model much.) self.balancer = ActivationBalancer( - d_model, channel_dim=-1, + embed_dim, channel_dim=-1, min_positive=0.05, max_positive=0.95, min_abs=0.1, max_abs=50.0, @@ -1486,9 +1373,9 @@ class ModifiedSEModule(nn.Module): ) self.activation = DoubleSwish() - self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, d_model) + self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim) - self.out_proj = ScaledLinear(d_model, d_model, + self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=False, initial_scale=0.1) self.out_whiten = Whiten(num_groups=1, @@ -1499,53 +1386,52 @@ class ModifiedSEModule(nn.Module): def forward(self, x: Tensor, - attn_weights: Tensor, - head_idx: int): + attn_weights: Tensor): """ Args: - x: a Tensor of shape (T, N, C) -attn_weights: a Tensor of shape (N * num_heads, seq_len, seq_len), we will only use the head indexed - `attn_weights_index` - head_idx: indicates which head to choose from attn_weights + x: a Tensor of shape (seq_len, batch_size, num_channels) +attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) Returns: - a Tensor of shape (T, N, C) + a Tensor with the same shape as x """ - (T, N, d_model) = x.shape - num_heads = attn_weights.shape[0] // N - attn_weights = attn_weights.reshape(N, num_heads, T, T) - attn_weights = attn_weights[:,head_idx] # (N, T, T) + num_heads = attn_weights.shape[0] + bottleneck = self.to_bottleneck_proj(x) # (seq_len, batch_size, bottleneck_dim) + (seq_len, batch_size, bottleneck_dim) = bottleneck.shape + head_dim = bottleneck_dim // num_heads + bottleneck = bottleneck.reshape(seq_len, batch_size, num_heads, head_dim).permute( + 2, 1, 0, 3) # (num_heads, batch_size, seq_len, head_dim) - bottleneck = self.to_bottleneck_proj(x) # (T, N, C) - bottleneck = bottleneck.transpose(0, 1) # (N, T, bottleneck_dim) - - # (N, T, T) x (N, T, bottleneck_dim) -> (N, T, bottleneck_dim) - bottleneck = torch.bmm(attn_weights, bottleneck) + # (num_heads, batch_size, seq_len, seq_len) x (num_heads, batch_size, seq_len, head_dim) + # -> (num_heads, batch_size, seq_len, head_dim) + bottleneck = torch.matmul(attn_weights, bottleneck) bottleneck = self.balancer(bottleneck) bottleneck = self.activation(bottleneck) - bottleneck = bottleneck.transpose(0, 1) # (T, N, bottleneck_dim) + bottleneck = bottleneck.permute(2, 1, 0, 3) # (seq_len, batch_size, num_heads, head_dim) + bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim) scales = self.from_bottleneck_proj(bottleneck) - x = self.in_proj(x) x = x * scales - return self.out_whiten(self.out_proj(x)) + x = self.out_proj(x) + x = self.out_whiten(x) + return x class FeedforwardModule(nn.Module): """Feedforward module in Zipformer model. 
""" def __init__(self, - d_model: int, + embed_dim: int, feedforward_dim: int, dropout: float): super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(d_model, feedforward_dim) + self.in_proj = nn.Linear(embed_dim, feedforward_dim) self.balancer = ActivationBalancer(feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25) self.activation = DoubleSwish() self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, + self.out_proj = ScaledLinear(feedforward_dim, embed_dim, initial_scale=0.01) def forward(self, @@ -1596,36 +1482,35 @@ class NonlinAttentionModule(nn.Module): def forward(self, x: Tensor, attn_weights: Tensor, - head_idx: int, ) -> Tensor: """. Args: - x: a Tensor of shape (T, N, C), i.e. (time, batch, channels) - attn_weights: a Tensor of shape (N * num_heads, seq_len, seq_len), we will only use the 1st head. - head_idx: indicates which head to choose from attn_weights + x: a Tensor of shape (seq_len, batch_size, num_channels) +attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) Returns: - a Tensor of shape (T, N, C) + a Tensor with the same shape as x """ - s = self.to_scale(x) v = self.to_value(x) if self.training and random.random() < 0.02: # prevent the inputs to the sigmoid from getting very large (this is - # unlikely to happen in this particular module, so giving this path - # a very small probability). - s = penalize_abs_values_gt(s, limit=20.0, penalty=1.0e-04) + # hopefully quite a rare phenomenon, so we are giving this path a + # very small probability to save time). + s = penalize_abs_values_gt(s, limit=20.0, penalty=1.0e-04) # GLU mechanism x = s.sigmoid() * v - (T, N, d_model) = x.shape - num_heads = attn_weights.shape[0] // N - attn_weights = attn_weights.reshape(N, num_heads, T, T) - attn_weights = attn_weights[:,head_idx] # (N, T, T) - x = x.transpose(0, 1) # (N, T, C) - x = torch.bmm(attn_weights, x) + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) x = self.deriv_balancer(x) - x = x.transpose(0, 1) # (T, N, C) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) x = self.activation(x) x = self.out_proj(x) @@ -1745,7 +1630,7 @@ class ConvolutionModule(nn.Module): class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). + """Convolutional 2D subsampling (to 1/2 length). Convert an input of shape (N, T, idim) to an output with shape (N, T', odim), where @@ -1979,7 +1864,7 @@ def _test_zipformer_main(): # Just make sure the forward pass runs. c = Zipformer( - num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4) + num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4) ) batch_size = 5 seq_len = 20