Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Loosen limit on param_max_rms, from 2.0 to 3.0; change how param_min_rms is applied.

commit e2fdfe990c
parent eb77fa7aaa
@@ -78,7 +78,7 @@ class ScaledAdam(Optimizer):
         scalar_lr_scale=0.1,
         eps=1.0e-08,
         param_min_rms=1.0e-05,
-        param_max_rms=2.0,
+        param_max_rms=3.0,
         scalar_max=2.0,
         size_update_period=4,
         clipping_update_period=100,
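The new default is picked up automatically at construction time, or can be overridden to keep the old behaviour. A minimal usage sketch; the import path, model, and learning rate are assumptions for illustration, not part of this commit:

```python
import torch

# Hypothetical import: in icefall, ScaledAdam lives in a recipe-local
# optim.py (e.g. egs/librispeech/ASR/zipformer/optim.py).
from optim import ScaledAdam

model = torch.nn.Linear(512, 512)

# param_max_rms now defaults to 3.0; passing 2.0 explicitly restores the
# previous, tighter cap on each parameter tensor's root-mean-square value.
optimizer = ScaledAdam(model.parameters(), lr=0.045, param_max_rms=2.0)
```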
@@ -272,7 +272,6 @@ class ScaledAdam(Optimizer):
         """
         lr = group["lr"]
         size_update_period = group["size_update_period"]
-        eps = group["eps"]
         beta1 = group["betas"][0]

         grad = p.grad
@@ -289,7 +288,7 @@ class ScaledAdam(Optimizer):
             scale_grads[step % size_update_period] = (p * grad).sum()
             if step % size_update_period == size_update_period - 1:
                 param_rms = state["param_rms"]
-                param_rms.copy_((p**2).mean().sqrt().clamp_(min=eps))
+                param_rms.copy_((p**2).mean().sqrt())
                 if step > 0:
                     # self._size_update() learns the overall scale on the
                     # parameter, by shrinking or expanding it.
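After this hunk the stored `param_rms` is the true root-mean-square of the parameter, no longer floored at `eps`; the minimum is enforced where the value is consumed instead (see the later hunks). A standalone sketch of the new computation, with an invented tensor:

```python
import torch

p = torch.randn(256, 256) * 1e-7  # a parameter with very small magnitude

# New behaviour: store the exact RMS, however small it is.
param_rms = (p**2).mean().sqrt()

# The floor now lives downstream: _size_update() zeroes the shrinking step
# when param_rms < param_min_rms, and _step() clamps the RMS that scales
# the learning rate.
```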
@@ -330,6 +329,7 @@ class ScaledAdam(Optimizer):
         size_lr = group["lr"] * group["scalar_lr_scale"]
         param_min_rms = group["param_min_rms"]
         param_max_rms = group["param_max_rms"]
+        eps = group["eps"]
         step = state["step"]

         size_update_period = scale_grads.shape[0]
@@ -348,7 +348,6 @@ class ScaledAdam(Optimizer):
         # we don't bother with bias_correction1; this will help prevent divergence
         # at the start of training.

-        eps = 1.0e-10 # this value is not critical.
         denom = scale_exp_avg_sq.sqrt() + eps

         scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
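With the hard-coded `1.0e-10` gone, the denominator now uses the group's `eps` read in the previous hunk. A standalone sketch of this scale-step computation; all values are invented for illustration, and the `bias_correction2` formula is assumed to be the standard Adam form:

```python
import torch

size_lr = 0.045 * 0.1        # lr * scalar_lr_scale
eps = 1.0e-08                # now group["eps"] rather than a local 1.0e-10
beta2, step = 0.98, 100
size_update_period = 4

# Stand-ins for the optimizer state: recent d(loss)/d(scale) samples and
# their running mean square.
scale_grads = torch.randn(size_update_period)
scale_exp_avg_sq = (scale_grads ** 2).mean()
bias_correction2 = 1 - beta2 ** (step + 1)   # assumption: standard Adam form

denom = scale_exp_avg_sq.sqrt() + eps
scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
```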
@@ -356,7 +355,9 @@ class ScaledAdam(Optimizer):
         is_too_small = (param_rms < param_min_rms)
         is_too_large = (param_rms > param_max_rms)

-        scale_step.masked_fill_(is_too_small, size_lr * size_update_period)
+        # when the param gets too small, just don't shrink it any further.
+        scale_step.masked_fill_(is_too_small, 0.0)
+        # when it gets too large, stop it from getting any larger.
         scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
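This is the behavioural core of the commit: a parameter whose RMS has fallen below `param_min_rms` used to be actively pushed back up by a fixed positive step; now its scale step is simply zeroed, so it just stops shrinking. A toy demonstration with invented values:

```python
import torch

param_min_rms, param_max_rms = 1.0e-05, 3.0
size_lr, size_update_period = 0.0045, 4

param_rms = torch.tensor([1.0e-06, 0.5, 4.0])    # too small / OK / too large
scale_step = torch.tensor([-0.01, -0.01, 0.01])  # proposed log-scale changes

is_too_small = param_rms < param_min_rms
is_too_large = param_rms > param_max_rms

# New: too-small parameters just stop shrinking rather than being grown.
scale_step.masked_fill_(is_too_small, 0.0)
# Unchanged: too-large parameters get a fixed negative step to stop growth.
scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)

print(scale_step)  # tensor([ 0.0000, -0.0100, -0.0180])
```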
@@ -383,8 +384,10 @@ class ScaledAdam(Optimizer):
         lr = group["lr"]
         beta1, beta2 = group["betas"]
         eps = group["eps"]
+        param_min_rms = group["param_min_rms"]
         step = state["step"]

+
         exp_avg_sq = state["exp_avg_sq"]
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
                                         value=(1-beta2))
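`param_min_rms` is now read in `_step()` because the clamp in the final hunk needs it. The second-moment update itself is unchanged, plain Adam; a toy sketch:

```python
import torch

beta2 = 0.98
grad = torch.randn(8)
exp_avg_sq = torch.zeros(8)

# Exponential moving average of grad**2, exactly as in the hunk above.
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
```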
@@ -399,8 +402,7 @@ class ScaledAdam(Optimizer):
         denom += eps
         grad = grad / denom

-
-        alpha = -lr * (1-beta1) * state["param_rms"]
+        alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms)

         delta = state["delta"]
         delta.add_(grad * alpha)
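This is the other half of the `param_min_rms` change: since the stored RMS is no longer floored when it is computed, the floor is applied here, where it scales the learning rate; without it, `alpha` would collapse toward zero for very small parameters. A toy sketch with invented numbers:

```python
import torch

lr, beta1 = 0.045, 0.9
param_min_rms = 1.0e-05

param_rms = torch.tensor(1.0e-07)  # an unusually small stored RMS
grad = torch.randn(8)

# The clamp keeps the effective step size from vanishing: alpha is
# computed here from 1.0e-05 rather than 1.0e-07.
alpha = -lr * (1 - beta1) * param_rms.clamp(min=param_min_rms)

delta = torch.zeros(8)
delta.add_(grad * alpha)  # momentum buffer update, as in the hunk above
```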