diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 178b57c53..6e2a4deb4 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -649,7 +649,7 @@ class AttentionDownsample(torch.nn.Module):
 
         scores = (src * self.query).sum(dim=-1, keepdim=True)
         scores = penalize_abs_values_gt(scores,
-                                        limit=5.0,
+                                        limit=10.0,
                                         penalty=1.0e-04)
 
         weights = scores.softmax(dim=1)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 52b03f692..5b4fece3d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -482,7 +482,7 @@ class ActivationBalancer(torch.nn.Module):
         scale_gain_factor: float = 0.02,
         min_abs: float = 0.2,
         max_abs: float = 100.0,
-        min_prob: float = 0.05,
+        min_prob: float = 0.1,
     ):
         super(ActivationBalancer, self).__init__()
         self.num_channels = num_channels
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 1493da1fb..ad1348e7f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -850,7 +850,7 @@ def train_one_epoch(
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()
-            if cur_grad_scale < 1.0:
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 logging.warn("Grad scale is small: {cur_grad_scale}")
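
A few notes on what the three tweaks do.

The `conformer.py` change doubles the headroom the attention-downsampling scores get before the soft limiter engages. As a rough sketch of the mechanism (the class name `PenalizeAbsValuesGT` is mine; the real `penalize_abs_values_gt` helper in `scaling.py` is implemented differently, via an attached auxiliary loss, but should have the same effective gradient): the op is an identity in the forward pass, and in the backward pass it adds a fixed `penalty * sign(x)` term wherever `|x|` exceeds the limit, as if `penalty * relu(|x| - limit).sum()` had been added to the training loss.

```python
import torch


class PenalizeAbsValuesGT(torch.autograd.Function):
    """Identity in forward; in backward, adds penalty * sign(x) to the
    gradient wherever |x| > limit, i.e. as if
    penalty * relu(|x| - limit).sum() were part of the training loss."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, limit: float, penalty: float) -> torch.Tensor:
        ctx.save_for_backward(x)
        ctx.limit, ctx.penalty = limit, penalty
        return x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (x,) = ctx.saved_tensors
        over = (x.abs() > ctx.limit).to(grad_output.dtype)
        # Constant-size nudge back toward [-limit, limit]; limit and penalty
        # are plain floats, so they get no gradient.
        return grad_output + ctx.penalty * x.sign() * over, None, None
```

Raising `limit` from 5.0 to 10.0 leaves the softmax inputs untouched over a wider range, so the penalty only engages on genuinely extreme scores.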
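
The `scaling.py` change raises the floor on how often `ActivationBalancer` does its work. The balancer is applied stochastically, with a probability that decays as training proceeds; `min_prob` is the floor of that schedule, so moving it from 0.05 to 0.1 means the activation constraints keep being enforced on at least ~10% of forward passes even late in training. A minimal sketch of such a schedule (the function name and the 0.5 / 4000-step decay constants are illustrative assumptions, not taken from `scaling.py`):

```python
import random


def balancer_apply_prob(step: int, min_prob: float) -> float:
    # Probability of applying the (relatively expensive) balancing pass:
    # decays from 0.5 toward a floor at min_prob as the step count grows.
    return max(min_prob, 0.5 ** (1 + step / 4000.0))


for step in (0, 4000, 40000):
    p = balancer_apply_prob(step, min_prob=0.1)
    if random.random() < p:
        pass  # apply the activation-balancing gradient correction here
    print(step, round(p, 3))  # 0 -> 0.5, 4000 -> 0.25, 40000 -> floored at 0.1
```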
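
The `train.py` change makes the manual grad-scale recovery in `train_one_epoch` more aggressive. This block sits under a periodic guard (assumed here to fire every 100 batches when fp16 is enabled): previously the scale was only doubled once it had collapsed below 1.0; now it also gets a gentler nudge, doubling every 400 batches while it is still below 8.0. (Incidentally, the unchanged `logging.warn` context line is missing its `f` prefix, so `cur_grad_scale` is not interpolated into the message.) A small simulation of the two rules, with `should_double` a hypothetical helper and the every-100-batches cadence assumed:

```python
def should_double(cur_grad_scale: float, batch_idx: int, new_rule: bool) -> bool:
    if new_rule:
        return cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0)
    return cur_grad_scale < 1.0  # old behavior


# A scale beaten down to 2.0 by fp16 overflows: the old rule never rescues it;
# the new rule doubles it at batches 400 and 800, then leaves it alone at 8.0.
for new_rule in (False, True):
    scale = 2.0
    for batch_idx in range(100, 1201, 100):  # assumed check cadence
        if should_double(scale, batch_idx, new_rule):
            scale *= 2.0
    print(new_rule, scale)  # False -> 2.0, True -> 8.0
```

In the actual loop the doubling goes through `scaler.update(cur_grad_scale * 2.0)`; passing an explicit value to `GradScaler.update` overrides its internal growth/backoff schedule for that step.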