mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Implement min,max param rms
This commit is contained in:
parent
3fb3cc4e23
commit
29dfe37f22
@ -88,7 +88,8 @@ class NeutralGradient(Optimizer):
|
||||
grad_eps=1e-8,
|
||||
param_eps=1.0e-06,
|
||||
param_rel_eps=1.0e-04,
|
||||
param_max=10.0,
|
||||
param_max_rms=2.0,
|
||||
param_min_rms=1.0e-05,
|
||||
max_fullcov_size=1023,
|
||||
estimate_period=2000,
|
||||
stats_steps=200,
|
||||
@ -115,8 +116,10 @@ class NeutralGradient(Optimizer):
|
||||
raise ValueError("Invalid grad_eps value: {}".format(grad_eps))
|
||||
if not 0.0 < param_eps < 0.1:
|
||||
raise ValueError("Invalid param_eps value: {}".format(param_eps))
|
||||
if not 0.1 < param_max:
|
||||
raise ValueError("Invalid param_max value: {}".format(param_max))
|
||||
if not 0.1 < param_max_rms < 1.0e+10:
|
||||
raise ValueError("Invalid param_max_rms value: {}".format(param_max_rms))
|
||||
if not 0.0 < param_min_rms < 0.5:
|
||||
raise ValueError("Invalid param_min_rms value: {}".format(param_min_rms))
|
||||
if not 0 < estimate_period:
|
||||
raise ValueError("Invalid estimate_period value: {}".format(estimate_period))
|
||||
if not 0 < stats_steps < estimate_period:
|
||||
@ -134,7 +137,8 @@ class NeutralGradient(Optimizer):
|
||||
grad_eps=grad_eps,
|
||||
param_eps=param_eps,
|
||||
param_rel_eps=param_rel_eps,
|
||||
param_max=param_max,
|
||||
param_max_rms=param_max_rms,
|
||||
param_min_rms=param_min_rms,
|
||||
max_fullcov_size=max_fullcov_size,
|
||||
estimate_period=estimate_period,
|
||||
stats_steps=stats_steps,
|
||||
@ -171,7 +175,8 @@ class NeutralGradient(Optimizer):
|
||||
grad_eps = group["grad_eps"]
|
||||
param_eps = group["param_eps"]
|
||||
param_rel_eps = group["param_rel_eps"]
|
||||
param_max = group["param_max"]
|
||||
param_max_rms = group["param_max_rms"]
|
||||
param_min_rms = group["param_min_rms"]
|
||||
max_fullcov_size = group["max_fullcov_size"]
|
||||
estimate_period = group["estimate_period"]
|
||||
stats_steps = group["stats_steps"]
|
||||
@ -328,13 +333,24 @@ class NeutralGradient(Optimizer):
|
||||
|
||||
scale_alpha = -lr * (scale_bias_correction2 ** 0.5) * scale_speed * (1-beta1)
|
||||
|
||||
|
||||
enforce_rms_period = 2
|
||||
scale_norm_deriv = (scale_deriv / scale_denom)
|
||||
if step % enforce_rms_period == 0:
|
||||
# Occasionally enforce the min/max-rms constraint.
|
||||
param_rms = (p ** 2).mean().sqrt()
|
||||
is_too_small = (param_rms < param_min_rms)
|
||||
is_too_large = (param_rms > param_max_rms)
|
||||
scale_norm_deriv.masked_fill_(is_too_small,
|
||||
-(n_cached_grads ** 0.5) * enforce_rms_period)
|
||||
scale_norm_deriv.masked_fill_(is_too_large,
|
||||
(n_cached_grads ** 0.5) * enforce_rms_period)
|
||||
#print(f"param_rms={param_rms}, param_min,max_rms={param_min_rms},{param_max_rms} is_too_small={is_too_small},is_too_large={is_too_large},scale_norm_deriv={scale_norm_deriv}")
|
||||
|
||||
# scale_delta is going to be the change in log-scale (before momentum), i.e. if
|
||||
# p = underlying_param * scale.exp(),
|
||||
# delta is the change in `scale`.
|
||||
scale_delta = scale_alpha * (scale_deriv / scale_denom)
|
||||
#if random.random() < 0.01:
|
||||
# print("scale_delta = ", scale_delta)
|
||||
|
||||
scale_delta = scale_alpha * scale_norm_deriv
|
||||
delta.add_(p, alpha=scale_delta)
|
||||
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
@ -416,13 +432,7 @@ class NeutralGradient(Optimizer):
|
||||
|
||||
if random.random() < 0.0001:
|
||||
print(f"Delta rms = {(delta**2).mean().item()}, shape = {delta.shape}")
|
||||
#max_abs = delta.abs().max()
|
||||
#if max_abs > 1.0:
|
||||
# print("Max_abs_change = ", max_abs)
|
||||
p.add_(delta)
|
||||
if step % 10 == 0:
|
||||
p.clamp_(min=-param_max, max=param_max)
|
||||
|
||||
|
||||
state["step"] += 1
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user