diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 1db062bd4..afaf864f0 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -105,7 +105,8 @@ class Zipformer(EncoderInterface):
         super(Zipformer, self).__init__()

         def _to_tuple(x):
-            """ Converts a single int or a 1-tuple of an int to a tuple with the same length as output_downsampling_factor"""
+            """ Converts a single int or a 1-tuple of an int to a tuple with the same length
+            as downsampling_factor"""
             if isinstance(x, int):
                 x = (x,)
             if len(x) == 1:
@@ -152,15 +153,15 @@ class Zipformer(EncoderInterface):
         num_encoders = len(downsampling_factor)
         for i in range(num_encoders):
             encoder_layer = ZipformerEncoderLayer(
-                encoder_dim[i],
-                pos_dim,
-                num_heads[i],
-                query_head_dim[i],
-                pos_head_dim[i],
-                value_head_dim[i],
-                feedforward_dim[i],
-                dropout,
-                cnn_module_kernel[i],
+                embed_dim=encoder_dim[i],
+                pos_dim=pos_dim,
+                num_heads=num_heads[i],
+                query_head_dim=query_head_dim[i],
+                pos_head_dim=pos_head_dim[i],
+                value_head_dim=value_head_dim[i],
+                feedforward_dim=feedforward_dim[i],
+                dropout=dropout,
+                cnn_module_kernel=cnn_module_kernel[i],
             )

             # For the segment of the warmup period, we let the Conv2dSubsampling
@@ -168,8 +169,8 @@ class Zipformer(EncoderInterface):
             encoder = ZipformerEncoder(
                 encoder_layer,
                 num_encoder_layers[i],
-                pos_dim,
-                dropout,
+                pos_dim=pos_dim,
+                dropout=dropout,
                 warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                 warmup_end=warmup_batches * (i + 2) / (num_encoders + 1)
             )
@@ -372,8 +373,9 @@ class ZipformerEncoderLayer(nn.Module):
         self.batch_count = 0

         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
-            embed_dim, pos_dim, num_heads,
-            query_head_dim, pos_head_dim, dropout=0.0,
+            embed_dim, pos_dim=pos_dim, num_heads=num_heads,
+            query_head_dim=query_head_dim, pos_head_dim=pos_head_dim,
+            dropout=0.0,
         )

         self.self_attn1 = SelfAttention(embed_dim, num_heads,
@@ -1099,17 +1101,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                                       bias=False, initial_scale=0.05)

-        # this is to stop a failure mode where the output gets very small and is
-        # dominated by the mean (the min_positive and max_positive will stop the mean
-        # being much larger than the variance). Make min_abs very small because even for normal,
-        # functional self_attn layers, the output rms can be very small.
-        self.out_balancer = ActivationBalancer(embed_dim,
-                                               channel_dim=-1,
-                                               min_positive=0.33,
-                                               max_positive=0.66,
-                                               min_abs=0.005, max_abs=1.0,
-                                               min_prob=0.05)
-
         # the following are for diagnosics only, see --print-diagnostics option
         self.copy_pos_query = Identity()
@@ -1141,8 +1132,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         query_head_dim = self.query_head_dim
         pos_head_dim = self.pos_head_dim
         num_heads = self.num_heads
-        dropout = self.dropout,
-        training = self.training,

         seq_len, batch_size, _ = x.shape

@@ -1165,6 +1154,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
         k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)

+        # time1 refers to target, time2 refers to source.
         q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
         p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
         k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)
@@ -1189,7 +1179,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

         attn_scores = torch.matmul(q, k) + pos_scores

-        if training and random.random() < 0.1:
+        if self.training and random.random() < 0.1:
             # This is a harder way of limiting the attention scores to not be
             # too large. It incurs a penalty if any of them has an absolute
             # value greater than 50.0. this should be outside the normal range
@@ -1214,10 +1204,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

         if key_padding_mask is not None:
             assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
-
-            attn_scores = attn_scores.view(
-                num_heads, batch_size, seq_len, seq_len
-            )
             attn_scores = attn_scores.masked_fill(
                 key_padding_mask.unsqueeze(1),
                 float("-inf"),
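Note on the last hunk: the removed .view() call is redundant because attn_scores = torch.matmul(q, k) + pos_scores already comes out shaped (num_heads, batch_size, time1, time2), so key_padding_mask.unsqueeze(1) broadcasts over heads and target positions as-is. A minimal standalone sketch of that broadcast (tensor sizes and the example padding pattern below are made up for illustration, not taken from the patch):

    import torch

    # Hypothetical sizes, just to exercise the broadcast.
    num_heads, batch_size, seq_len = 4, 2, 5

    # Stand-in for torch.matmul(q, k) + pos_scores: already
    # (num_heads, batch_size, time1, time2), so no .view() is needed.
    attn_scores = torch.randn(num_heads, batch_size, seq_len, seq_len)

    # key_padding_mask is (batch_size, time2); True marks padded source frames.
    key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    key_padding_mask[1, 3:] = True  # pretend utterance 1 has two padded frames

    # unsqueeze(1) -> (batch_size, 1, time2); broadcasting against
    # (num_heads, batch_size, time1, time2) masks the padded source frames
    # for every head and every target position.
    masked = attn_scores.masked_fill(key_padding_mask.unsqueeze(1), float("-inf"))

    assert torch.isinf(masked[:, 1, :, 3:]).all()   # padded frames masked everywhere
    assert not torch.isinf(masked[:, 0]).any()      # unpadded utterance untouched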