diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py index 2fe84cf21..74d6bf569 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py @@ -603,7 +603,7 @@ class Cain(Optimizer): scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"]) - scale_alpha = -lr*(scale_bias_correction2 ** 0.5) + scale_alpha = -lr * (scale_bias_correction2 ** 0.5) * (1-beta1) # scale_delta is going to be the change in log-scale (before momentum), i.e. if # p = underlying_param * scale.exp(), @@ -617,7 +617,8 @@ class Cain(Optimizer): # specified limit, start to decay the parameters (and ignore the computed # delta). scale_delta = torch.where(state["param_rms"] > rms_max, - torch.tensor(-1.0e-03, device=scale_delta.device, + torch.tensor(-1.0e-03 * (1-beta1), + device=scale_delta.device, dtype=scale_delta.dtype), scale_delta)