diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 210f89cc6..222946f2b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -294,6 +294,8 @@ class ConformerEncoderLayer(nn.Module):
 
         self.norm_final = BasicNorm(d_model)
 
+        self.bypass_scale = nn.Parameter(torch.tensor(0.5))
+
         # try to ensure the output is close to zero-mean (or at least, zero-median).
         self.balancer = ActivationBalancer(
             d_model, channel_dim=-1,
@@ -367,13 +369,13 @@ class ConformerEncoderLayer(nn.Module):
         src = self.norm_final(self.balancer(src))
 
         warmup_value = self.get_warmup_value(warmup_count)
+
+        delta = src - src_orig
         if warmup_value < 1.0 and self.training:
-            delta = src - src_orig
             keep_prob = 0.25 + 0.75 * warmup_value
             # the :1 means the mask is chosen per frame.
             delta = delta * (torch.rand_like(delta[...,:1]) < keep_prob)
-            src = src_orig + delta
-
+        src = src_orig + delta * self.bypass_scale
 
         return src, attn_scores_out
 
@@ -514,8 +516,6 @@ class ConformerEncoder(nn.Module):
                 src_key_padding_mask=src_key_padding_mask,
                 warmup_count=warmup_count,
             )
-            # this seemed to be helpful...
-            output = 0.5 * (next_output + output)
 
         output = output * feature_mask
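
For reference, here is a minimal standalone sketch of what the patched residual path in ConformerEncoderLayer ends up computing. This is not the icefall module itself: the class name BypassResidual and its forward() signature are invented for illustration, and only the bypass_scale parameter, the warmup-gated per-frame dropout of the delta, and the final src_orig + delta * bypass_scale combination are taken from the diff above.

# Minimal sketch (hypothetical wrapper, not the icefall code) of the new
# residual path: a learnable bypass_scale (initialized to 0.5) scales the
# layer's contribution, and while warmup_value < 1.0 a random per-frame
# subset of those contributions is dropped during training.
import torch
import torch.nn as nn


class BypassResidual(nn.Module):
    def __init__(self):
        super().__init__()
        self.bypass_scale = nn.Parameter(torch.tensor(0.5))

    def forward(
        self, src: torch.Tensor, src_orig: torch.Tensor, warmup_value: float
    ) -> torch.Tensor:
        delta = src - src_orig
        if warmup_value < 1.0 and self.training:
            # keep_prob ramps from 0.25 toward 1.0 as warmup progresses.
            keep_prob = 0.25 + 0.75 * warmup_value
            # The [..., :1] slice draws one random value per frame; the
            # resulting boolean mask broadcasts over the channel dimension.
            delta = delta * (torch.rand_like(delta[..., :1]) < keep_prob)
        return src_orig + delta * self.bypass_scale

Because each layer now learns its own bypass scale, the fixed output = 0.5 * (next_output + output) averaging in ConformerEncoder plays a similar role and becomes redundant, which appears to be why the third hunk removes it.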