diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index bbfb292b4..711db31b6 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -452,20 +452,25 @@ class ZipformerEncoderLayer(nn.Module):
         self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
 
-        self.norm_final = BasicNorm(embed_dim)
+        self.norm = BasicNorm(embed_dim)
 
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
 
-        # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = Balancer(
+        self.balancer1 = Balancer(
             embed_dim, channel_dim=-1,
             min_positive=0.45, max_positive=0.55,
-            min_abs=1.0, max_abs=6.0,
+            min_abs=1.0, max_abs=4.0,
         )
 
         self.whiten = Whiten(num_groups=1,
                              whitening_limit=_whitening_schedule(4.0, ratio=3.0),
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
 
+        self.balancer2 = Balancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.45, max_positive=0.55,
+            min_abs=0.5, max_abs=2.0,
+        )
+
     def remove_attention_weights(self):
         self.self_attn_weights = None
 
@@ -571,12 +576,15 @@ class ZipformerEncoderLayer(nn.Module):
 
         src = src + self.feed_forward2(src)
 
-        src = self.balancer(src)
-        src = self.norm_final(src)
+        src = self.balancer1(src)
+        src = self.norm(src)
 
         bypass_scale = self.get_bypass_scale(src.shape[1])
-        src = src * bypass_scale + src_orig * (1.0 - bypass_scale)
+        # The next line is equivalent to: src = src * bypass_scale + src_orig *
+        # (1.0 - bypass_scale), but more memory-efficient for backprop.
+        src = src_orig + (src - src_orig) * bypass_scale
 
+        src = self.balancer2(src)
         src = self.whiten(src)
 
         return src, attn_weights
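
For reference, the rewritten bypass update in the second hunk is algebraically identical to the form it replaces: `src * s + src_orig * (1 - s) == src_orig + (src - src_orig) * s`. Below is a minimal standalone sketch (plain PyTorch, not icefall code; the shapes and the per-channel `bypass_scale` layout are assumptions matching how `ZipformerEncoderLayer` uses them) that checks the equivalence numerically:

```python
import torch

# Illustrative shapes: (seq_len, batch, embed_dim), with a per-channel
# bypass scale broadcast over the leading dimensions.
seq_len, batch, embed_dim = 7, 3, 16
src_orig = torch.randn(seq_len, batch, embed_dim)
src = torch.randn(seq_len, batch, embed_dim)
bypass_scale = torch.full((embed_dim,), 0.5)

# Original formulation: two elementwise products, so autograd must save
# inputs for both multiplies.
old = src * bypass_scale + src_orig * (1.0 - bypass_scale)

# New formulation from the diff: a single product, so fewer intermediate
# tensors are kept alive for the backward pass.
new = src_orig + (src - src_orig) * bypass_scale

assert torch.allclose(old, new, atol=1e-6)
```

The two expressions differ only by floating-point rounding; the rewrite trades one of the multiplies for an add/subtract pair, which is what makes it cheaper to backprop through.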