diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 42e18f31a..8cc76c9c1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -66,12 +66,12 @@ class NeutralGradient(Optimizer):
         params,
         lr=1e-2,
         betas=(0.9, 0.98, 0.98),
-        scale_speed=0.025,
+        scale_speed=0.05,
         eps=1e-8,
         param_eps=1.0e-05,
         cond_eps=1.0e-10,
         param_max=10.0,
-        min_diag_smooth=0.2,
+        min_diag_smooth=0.5,
         max_size=1023,
         stats_period=1,
         estimate_period=200,
@@ -378,14 +378,17 @@ class NeutralGradient(Optimizer):
         assert step % stats_period == 0
         norm_step = step // stats_period
 
-        def update_ref_exp_avg_sq(ref_exp_avg_sq, grad_cov_smoothed, dim):
+        scale_change = True
+
+        def update_ref_exp_avg_sq(ref_exp_avg_sq, grad_cov_smoothed, dim,
+                                  scale_change):
             """
             Updates ref_exp_avg_sq, taking into account `grad_cov_smoothed` which is
             a tensor of shape (p.shape[dim],) containing the smoothed diagonal covariance
             of grads for dimension `dim`.
             Returns new ref_exp_avg_sq
             """
-            if dim > 0:
+            if not scale_change:
                 # We only want to "count" the overall scale of the gradients' variance
                 # (i.e. mean variance) once, so for all dims other than 0 we normalize
                 # this out.
@@ -412,12 +415,14 @@ class NeutralGradient(Optimizer):
                     grad_var_smoothed = grad_cov + eps*eps
                     ref_exp_avg_sq = update_ref_exp_avg_sq(ref_exp_avg_sq, grad_var_smoothed,
-                                                           dim)
-                    if dim != 0:
-                        # If dim != 0, make the grad_var_smoothed and
-                        # param_var_smoothed have the same mean, because when
-                        # preconditioning gradients, we only want to count the
-                        # overall change in scale once.
+                                                           dim, scale_change)
+                    if scale_change:
+                        scale_change = False  # only use the scale change once
+                    else:
+                        # For all but the first non-trivial dim, make the
+                        # grad_var_smoothed and param_var_smoothed have the same
+                        # mean, because when preconditioning gradients, we only want
+                        # to count the overall change in scale once.
                         grad_var_smoothed *= (param_var_smoothed.sum() /
                                               grad_var_smoothed.sum())
                     proj[:] = (param_var_smoothed / grad_var_smoothed).sqrt()
                 else:
@@ -432,9 +437,11 @@ class NeutralGradient(Optimizer):
                                                          min_diag_smooth=min_diag_smooth)
                     ref_exp_avg_sq = update_ref_exp_avg_sq(ref_exp_avg_sq,
                                                            grad_cov_smoothed.diag(),
-                                                           dim)
+                                                           dim, scale_change)
 
-                    if dim != 0:
+                    if scale_change:
+                        scale_change = False  # only use the scale change once
+                    else:
                         # If dim != 0, make the traces of grad_cov_smoothed and
                         # param_cov_smoothed be the same, because when
                         # preconditioning gradients, we only want to count the
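
For reference, here is a minimal standalone sketch (illustrative only, not the icefall code; `diagonal_projections` and its inputs are invented for this example) of the pattern the new `scale_change` flag implements: the first dim processed keeps the raw gradient variance, and every later dim is rescaled so that its variance has the same mean as the parameter variance, so the overall change of scale is counted exactly once.

import torch


def diagonal_projections(param, grad, eps=1e-8):
    """For each dim of `param`, return a 1-D factor
    proj = sqrt(param_var / grad_var) along that dim, counting the
    overall gradient scale only once (for the first dim processed)."""
    projections = []
    scale_change = True  # the overall scale change is only "used" once
    for dim in range(param.ndim):
        other_dims = [d for d in range(param.ndim) if d != dim]
        if other_dims:
            # per-index variance of params and grads along this dim
            param_var = (param ** 2).mean(dim=other_dims) + eps * eps
            grad_var = (grad ** 2).mean(dim=other_dims) + eps * eps
        else:  # 1-D parameter: this dim is the only axis
            param_var = param ** 2 + eps * eps
            grad_var = grad ** 2 + eps * eps
        if scale_change:
            scale_change = False  # only use the scale change once
        else:
            # Later dims contribute only a shape, not another change of
            # scale: give grad_var the same mean as param_var.
            grad_var = grad_var * (param_var.sum() / grad_var.sum())
        projections.append((param_var / grad_var).sqrt())
    return projections


if __name__ == "__main__":
    p = torch.randn(4, 3)
    g = torch.randn(4, 3) * 0.1
    for d, proj in enumerate(diagonal_projections(p, g)):
        print(f"dim {d}: shape {tuple(proj.shape)}, mean {proj.mean().item():.3f}")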