diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 8fe3fffaa..8d16d7b1d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -54,9 +54,6 @@ class LearnedGradient(Optimizer):
        diagonalize_period: How often we re-diagonalize the gradient covariance; this
              gets multiplied by lr_est_period, i.e. we diagonalize less frequently
              than we update the learning-rate matrixes.
-       max_step_scale: Value that determines limit on how large steps can be per element,
-             intended to prevent instability during early steps before the learning-rate
-             matrices are well learned.
     """
     def __init__(
         self,
@@ -71,9 +68,8 @@ class LearnedGradient(Optimizer):
         param_max_rms=2.0,
         lr_mat_min=0.01,
         lr_mat_max=4.0,
-        lr_est_period=10,
+        lr_est_period=5,
         diagonalize_period=4,
-        max_step_scale=2.0
     ):
         defaults = dict(
             lr=lr,
@@ -90,7 +86,6 @@ class LearnedGradient(Optimizer):
             lr_mat_max=lr_mat_max,
             lr_est_period=lr_est_period,
             diagonalize_period=diagonalize_period,
-            max_step_scale=max_step_scale,
         )
 
         super(LearnedGradient, self).__init__(params, defaults)
@@ -125,7 +120,6 @@ class LearnedGradient(Optimizer):
             param_max_rms = group["param_max_rms"]
             lr_est_period = group["lr_est_period"]
             diagonalize_period = group["diagonalize_period"]
-            max_step_scale = group["max_step_scale"]
 
             for p in group["params"]:
                 if p.grad is None:
@@ -232,6 +226,8 @@ class LearnedGradient(Optimizer):
 
                 step = state["step"]
 
+                delta = state["delta"]
+                delta.mul_(beta1)
                 numel = p.numel()
                 if numel > 1:
                     # Update the size/scale of p, and set param_rms
@@ -248,8 +244,7 @@ class LearnedGradient(Optimizer):
                                        beta1, beta2, step, size_lr,
                                        param_min_rms, param_max_rms)
 
-                delta = state["delta"]
-                delta.mul_(beta1)
+
                 if numel == 1:
                     # For parameters with very few elements we just use a form
                     # of Adam with a scale factor to reflect the overall
@@ -482,8 +477,6 @@ class LearnedGradient(Optimizer):
 
         M_delta = torch.matmul(M, proj_delta)  # (eq.5), M_delta = M proj_delta
 
-        alpha = -meta_lr * (1-beta1)
-
         M_delta_full = M_delta.reshape(M_full.shape)
 
         this_delta = M_delta_full.transpose(dim, -1)
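
Note on the step() change: the decay of the momentum buffer, delta.mul_(beta1), is hoisted from after the size-update call to the top of the per-parameter loop, so the buffer is already decayed before either the size-update path or the numel == 1 path touches it (the removed alpha = -meta_lr * (1-beta1) in the learning-rate-estimation code appears to be dead code). Below is a minimal runnable sketch of that ordering, not the icefall code itself; per_param_update and the 0.01 learning rate are illustrative names only:

import torch

beta1 = 0.9

def per_param_update(state, p, grad):
    # As in the patched ordering: decay the momentum buffer once, up front,
    # before any branch-specific logic runs.
    delta = state["delta"]
    delta.mul_(beta1)

    if p.numel() > 1:
        # ... the size/scale update would run here; it can now accumulate
        # into an already-decayed delta ...
        pass

    # Branch-specific logic then adds its step contribution to the same buffer.
    delta.add_(grad, alpha=-(1 - beta1) * 0.01)  # 0.01 is an illustrative lr
    p.add_(delta)

p = torch.zeros(3)
state = {"delta": torch.zeros_like(p)}
per_param_update(state, p, torch.ones(3))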