diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 9c47abb30..17d9f7672 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -722,7 +722,13 @@ class SimpleCombiner(torch.nn.Module): weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True) weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True) - weight = (weight1 + weight2).sigmoid() + logit = (weight1 + weight2) + + if self.training and random.random() < 0.1: + logit = penalize_abs_values_gt(logit, + limit=10.0, + penalty=1.0e-04) + weight = logit.sigmoid() src2_part1 = src2[...,:dim1] part1 = src1 * weight + src2_part1 * (1.0 - weight) @@ -1569,9 +1575,10 @@ class AttentionCombine(nn.Module): scores = scores.masked_fill(mask, float('-inf')) - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + if self.training and random.random() < 0.1: + scores = penalize_abs_values_gt(scores, + limit=10.0, + penalty=1.0e-04) weights = scores.softmax(dim=1)