From 29dfe37f2254af1e01ebf8a3fd054efa15e7ced6 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 18 Jun 2022 22:22:23 +0800
Subject: [PATCH] Implement min,max param rms

---
 .../ASR/pruned_transducer_stateless7/optim.py | 40 ++++++++++++-------
 1 file changed, 25 insertions(+), 15 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 1ea201fce..5a5e2c823 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -88,7 +88,8 @@ class NeutralGradient(Optimizer):
         grad_eps=1e-8,
         param_eps=1.0e-06,
         param_rel_eps=1.0e-04,
-        param_max=10.0,
+        param_max_rms=2.0,
+        param_min_rms=1.0e-05,
         max_fullcov_size=1023,
         estimate_period=2000,
         stats_steps=200,
@@ -115,8 +116,10 @@ class NeutralGradient(Optimizer):
             raise ValueError("Invalid grad_eps value: {}".format(grad_eps))
         if not 0.0 < param_eps < 0.1:
             raise ValueError("Invalid param_eps value: {}".format(param_eps))
-        if not 0.1 < param_max:
-            raise ValueError("Invalid param_max value: {}".format(param_max))
+        if not 0.1 < param_max_rms < 1.0e+10:
+            raise ValueError("Invalid param_max_rms value: {}".format(param_max_rms))
+        if not 0.0 < param_min_rms < 0.5:
+            raise ValueError("Invalid param_min_rms value: {}".format(param_min_rms))
         if not 0 < estimate_period:
             raise ValueError("Invalid estimate_period value: {}".format(estimate_period))
         if not 0 < stats_steps < estimate_period:
@@ -134,7 +137,8 @@ class NeutralGradient(Optimizer):
             grad_eps=grad_eps,
             param_eps=param_eps,
             param_rel_eps=param_rel_eps,
-            param_max=param_max,
+            param_max_rms=param_max_rms,
+            param_min_rms=param_min_rms,
             max_fullcov_size=max_fullcov_size,
             estimate_period=estimate_period,
             stats_steps=stats_steps,
@@ -171,7 +175,8 @@ class NeutralGradient(Optimizer):
             grad_eps = group["grad_eps"]
             param_eps = group["param_eps"]
             param_rel_eps = group["param_rel_eps"]
-            param_max = group["param_max"]
+            param_max_rms = group["param_max_rms"]
+            param_min_rms = group["param_min_rms"]
             max_fullcov_size = group["max_fullcov_size"]
             estimate_period = group["estimate_period"]
             stats_steps = group["stats_steps"]
@@ -328,13 +333,24 @@ class NeutralGradient(Optimizer):
 
                     scale_alpha = -lr * (scale_bias_correction2 ** 0.5) * scale_speed * (1-beta1)
 
+
+                    enforce_rms_period = 2
+                    scale_norm_deriv = (scale_deriv / scale_denom)
+                    if step % enforce_rms_period == 0:
+                        # Occasionally enforce the min/max-rms constraint.
+                        param_rms = (p ** 2).mean().sqrt()
+                        is_too_small = (param_rms < param_min_rms)
+                        is_too_large = (param_rms > param_max_rms)
+                        scale_norm_deriv.masked_fill_(is_too_small,
+                                                      -(n_cached_grads ** 0.5) * enforce_rms_period)
+                        scale_norm_deriv.masked_fill_(is_too_large,
+                                                      (n_cached_grads ** 0.5) * enforce_rms_period)
+                        #print(f"param_rms={param_rms}, param_min,max_rms={param_min_rms},{param_max_rms} is_too_small={is_too_small},is_too_large={is_too_large},scale_norm_deriv={scale_norm_deriv}")
+
                     # scale_delta is going to be the change in log-scale (before momentum), i.e. if
                     # p = underlying_param * scale.exp(),
                     # delta is the change in `scale`.
-                    scale_delta = scale_alpha * (scale_deriv / scale_denom)
-                    #if random.random() < 0.01:
-                    #    print("scale_delta = ", scale_delta)
-
+                    scale_delta = scale_alpha * scale_norm_deriv
                     delta.add_(p, alpha=scale_delta)
 
                 exp_avg_sq = state["exp_avg_sq"]
@@ -416,13 +432,7 @@ class NeutralGradient(Optimizer):
 
             if random.random() < 0.0001:
                 print(f"Delta rms = {(delta**2).mean().item()}, shape = {delta.shape}")
-            #max_abs = delta.abs().max()
-            #if max_abs > 1.0:
-            #    print("Max_abs_change = ", max_abs)
 
             p.add_(delta)
-            if step % 10 == 0:
-                p.clamp_(min=-param_max, max=param_max)
-
 
             state["step"] += 1
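
Note on the mechanism in the @@ -328,13 +333,24 @@ hunk: instead of the old element-wise p.clamp_(min=-param_max, max=param_max), the constraint now acts on the tensor-level root-mean-square via the learned log-scale. Whenever param_rms falls outside [param_min_rms, param_max_rms], the normalized log-scale derivative is overwritten with a fixed-magnitude value; since scale_alpha carries the -lr factor and is negative, a negative fill grows the parameter and a positive fill shrinks it. The sketch below isolates just that mechanism; it is a minimal illustration, not code from the repository, and the values for n_cached_grads and the helper name rms_constrained_scale_deriv are assumed stand-ins.

    import torch

    def rms_constrained_scale_deriv(p: torch.Tensor,
                                    scale_norm_deriv: torch.Tensor,
                                    param_min_rms: float = 1.0e-05,
                                    param_max_rms: float = 2.0,
                                    n_cached_grads: int = 4,
                                    enforce_rms_period: int = 2) -> torch.Tensor:
        # Root-mean-square over the whole parameter tensor (a 0-dim tensor).
        param_rms = (p ** 2).mean().sqrt()
        # 0-dim boolean masks; masked_fill_ broadcasts them over the target.
        is_too_small = param_rms < param_min_rms
        is_too_large = param_rms > param_max_rms
        out = scale_norm_deriv.clone()
        # The eventual update is scale_alpha * deriv with scale_alpha < 0,
        # so a negative fill grows the scale and a positive fill shrinks it.
        out.masked_fill_(is_too_small, -(n_cached_grads ** 0.5) * enforce_rms_period)
        out.masked_fill_(is_too_large, (n_cached_grads ** 0.5) * enforce_rms_period)
        return out

    # RMS of p is 10.0 > param_max_rms, so the derivative is overwritten
    # with +4.0, which pushes the log-scale down on the next update.
    p = torch.full((8,), 10.0)
    deriv = torch.zeros(())        # stand-in for scale_deriv / scale_denom
    print(rms_constrained_scale_deriv(p, deriv))  # tensor(4.)

The trade-off versus the removed clamp: the old code truncated individual elements every 10 steps, while the new constraint targets the tensor's overall RMS and folds the correction into scale_delta, so it is smoothed by the same (1-beta1) momentum as the regular scale update.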