diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index a6327b30e..66c03d17c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -370,8 +370,9 @@ class ConformerEncoderLayer(nn.Module): warmup_value = self.get_warmup_value(warmup_count) if warmup_value < 1.0 and self.training: delta = src - src_orig - keep_prob = 0.5 * (1. + warmup_value) - delta = delta * (torch.rand_like(delta) < keep_prob) + keep_prob = 0.5 + 0.5 * warmup_value + # the :1 means the mask is chosen per frame. + delta = delta * (torch.rand_like(delta[...,:1]) < keep_prob) src = src_orig + delta @@ -482,7 +483,6 @@ class ConformerEncoder(nn.Module): output = 0.5 * (next_output + output) output = output * feature_mask - output_mean = output.abs().mean().item() return output