diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index ff8fbb32c..2a7c3d89a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -561,8 +561,11 @@ class ScaledAdam(BatchedOptimizer):
 
         # when the param gets too small, just don't shrink it any further.
         scale_step.masked_fill_(is_too_small, 0.0)
-        # when it gets too large, stop it from getting any larger.
-        scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
+
+        # and ensure the parameter rms after update never exceeds param_max_rms.
+        scale_step = torch.minimum(scale_step,
+                                   (param_max_rms - param_rms) / param_rms)
+
         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
         delta.add_(p * scale_step, alpha=(1 - beta1))
@@ -776,7 +779,9 @@ class Eden(LRScheduler):
 
 def _test_eden():
     m = torch.nn.Linear(100, 100)
-    optim = ScaledAdam(m.parameters(), lr=0.03)
+    parameters_names = [[x[0] for x in m.named_parameters()]]
+    optim = ScaledAdam(m.parameters(), lr=0.03,
+                       parameters_names=parameters_names)
 
     scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
 
@@ -987,7 +992,9 @@ def _test_scaled_adam(hidden_dim: int):
         if iter == 0:
             optim = Eve(m.parameters(), lr=0.003)
         elif iter == 1:
-            optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
+            parameters_names = [[x[0] for x in m.named_parameters()]]
+            optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0,
+                               parameters_names=parameters_names)
         scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
 
         start = timeit.default_timer()
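For reference, a minimal standalone sketch of what the first hunk changes: the `torch.minimum` clamp limits the proposed relative size step so that (ignoring the momentum factor applied via `delta`) a parameter's RMS after scaling by `(1 + scale_step)` cannot exceed `param_max_rms`, rather than merely freezing the step once `is_too_large` has already triggered. The numeric values below are hypothetical and not taken from the recipe; each row stands in for one parameter of a batched group, as in `BatchedOptimizer`.

```python
import torch

# Hypothetical values, for illustration only (not the recipe's defaults).
param_max_rms = 2.0

# Treat each row as one parameter in a batch, as BatchedOptimizer does.
p = torch.randn(4, 100) * 1.9
param_rms = (p ** 2).mean(dim=1, keepdim=True).sqrt()

# Suppose the size update proposes a large positive relative step.
scale_step = torch.full_like(param_rms, 0.5)

# New behaviour from the diff: clamp the step so that, ignoring momentum,
#   param_rms * (1 + scale_step) <= param_max_rms.
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)

new_rms = ((p * (1 + scale_step)) ** 2).mean(dim=1, keepdim=True).sqrt()
assert (new_rms <= param_max_rms + 1e-6).all()
```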
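The two test hunks show that `ScaledAdam` on this branch expects a `parameters_names` argument: a list with one inner list of parameter names per parameter group, in the same order as the parameters themselves. Below is a sketch of how a caller with a larger single-group model might build it, mirroring what the updated tests do. The helper name `make_parameters_names` is made up for this example, and it assumes `optim.py` from this recipe directory is importable.

```python
import torch
from optim import ScaledAdam  # optim.py from this recipe directory


def make_parameters_names(model: torch.nn.Module) -> list:
    # One inner list per parameter group; a single group here, so all names
    # go into one list, in the same order as model.parameters() yields them.
    return [[name for name, _ in model.named_parameters()]]


model = torch.nn.Sequential(
    torch.nn.Linear(80, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 500),
)
optim = ScaledAdam(
    model.parameters(),
    lr=0.03,
    clipping_scale=2.0,
    parameters_names=make_parameters_names(model),
)
```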