Implement min and max limits on parameter RMS

This commit is contained in:
Daniel Povey 2022-06-18 22:22:23 +08:00
parent 3fb3cc4e23
commit 29dfe37f22

View File

@ -88,7 +88,8 @@ class NeutralGradient(Optimizer):
grad_eps=1e-8, grad_eps=1e-8,
param_eps=1.0e-06, param_eps=1.0e-06,
param_rel_eps=1.0e-04, param_rel_eps=1.0e-04,
param_max=10.0, param_max_rms=2.0,
param_min_rms=1.0e-05,
max_fullcov_size=1023, max_fullcov_size=1023,
estimate_period=2000, estimate_period=2000,
stats_steps=200, stats_steps=200,
@ -115,8 +116,10 @@ class NeutralGradient(Optimizer):
raise ValueError("Invalid grad_eps value: {}".format(grad_eps)) raise ValueError("Invalid grad_eps value: {}".format(grad_eps))
if not 0.0 < param_eps < 0.1: if not 0.0 < param_eps < 0.1:
raise ValueError("Invalid param_eps value: {}".format(param_eps)) raise ValueError("Invalid param_eps value: {}".format(param_eps))
if not 0.1 < param_max: if not 0.1 < param_max_rms < 1.0e+10:
raise ValueError("Invalid param_max value: {}".format(param_max)) raise ValueError("Invalid param_max_rms value: {}".format(param_max_rms))
if not 0.0 < param_min_rms < 0.5:
raise ValueError("Invalid param_min_rms value: {}".format(param_min_rms))
if not 0 < estimate_period: if not 0 < estimate_period:
raise ValueError("Invalid estimate_period value: {}".format(estimate_period)) raise ValueError("Invalid estimate_period value: {}".format(estimate_period))
if not 0 < stats_steps < estimate_period: if not 0 < stats_steps < estimate_period:
@ -134,7 +137,8 @@ class NeutralGradient(Optimizer):
grad_eps=grad_eps, grad_eps=grad_eps,
param_eps=param_eps, param_eps=param_eps,
param_rel_eps=param_rel_eps, param_rel_eps=param_rel_eps,
param_max=param_max, param_max_rms=param_max_rms,
param_min_rms=param_min_rms,
max_fullcov_size=max_fullcov_size, max_fullcov_size=max_fullcov_size,
estimate_period=estimate_period, estimate_period=estimate_period,
stats_steps=stats_steps, stats_steps=stats_steps,
@ -171,7 +175,8 @@ class NeutralGradient(Optimizer):
grad_eps = group["grad_eps"] grad_eps = group["grad_eps"]
param_eps = group["param_eps"] param_eps = group["param_eps"]
param_rel_eps = group["param_rel_eps"] param_rel_eps = group["param_rel_eps"]
param_max = group["param_max"] param_max_rms = group["param_max_rms"]
param_min_rms = group["param_min_rms"]
max_fullcov_size = group["max_fullcov_size"] max_fullcov_size = group["max_fullcov_size"]
estimate_period = group["estimate_period"] estimate_period = group["estimate_period"]
stats_steps = group["stats_steps"] stats_steps = group["stats_steps"]
@ -328,13 +333,24 @@ class NeutralGradient(Optimizer):
scale_alpha = -lr * (scale_bias_correction2 ** 0.5) * scale_speed * (1-beta1) scale_alpha = -lr * (scale_bias_correction2 ** 0.5) * scale_speed * (1-beta1)
enforce_rms_period = 2
scale_norm_deriv = (scale_deriv / scale_denom)
if step % enforce_rms_period == 0:
# Occasionally enforce the min/max-rms constraint.
param_rms = (p ** 2).mean().sqrt()
is_too_small = (param_rms < param_min_rms)
is_too_large = (param_rms > param_max_rms)
scale_norm_deriv.masked_fill_(is_too_small,
-(n_cached_grads ** 0.5) * enforce_rms_period)
scale_norm_deriv.masked_fill_(is_too_large,
(n_cached_grads ** 0.5) * enforce_rms_period)
#print(f"param_rms={param_rms}, param_min,max_rms={param_min_rms},{param_max_rms} is_too_small={is_too_small},is_too_large={is_too_large},scale_norm_deriv={scale_norm_deriv}")
# scale_delta is going to be the change in log-scale (before momentum), i.e. if # scale_delta is going to be the change in log-scale (before momentum), i.e. if
# p = underlying_param * scale.exp(), # p = underlying_param * scale.exp(),
# delta is the change in `scale`. # delta is the change in `scale`.
scale_delta = scale_alpha * (scale_deriv / scale_denom) scale_delta = scale_alpha * scale_norm_deriv
#if random.random() < 0.01:
# print("scale_delta = ", scale_delta)
delta.add_(p, alpha=scale_delta) delta.add_(p, alpha=scale_delta)
exp_avg_sq = state["exp_avg_sq"] exp_avg_sq = state["exp_avg_sq"]
@ -416,13 +432,7 @@ class NeutralGradient(Optimizer):
if random.random() < 0.0001: if random.random() < 0.0001:
print(f"Delta rms = {(delta**2).mean().item()}, shape = {delta.shape}") print(f"Delta rms = {(delta**2).mean().item()}, shape = {delta.shape}")
#max_abs = delta.abs().max()
#if max_abs > 1.0:
# print("Max_abs_change = ", max_abs)
p.add_(delta) p.add_(delta)
if step % 10 == 0:
p.clamp_(min=-param_max, max=param_max)
state["step"] += 1 state["step"] += 1