Loosen the limit on param_max_rms from 2.0 to 3.0; change how param_min_rms is applied.

Daniel Povey 2022-09-20 15:20:43 +08:00
parent eb77fa7aaa
commit e2fdfe990c


@@ -78,7 +78,7 @@ class ScaledAdam(Optimizer):
         scalar_lr_scale=0.1,
         eps=1.0e-08,
         param_min_rms=1.0e-05,
-        param_max_rms=2.0,
+        param_max_rms=3.0,
         scalar_max=2.0,
         size_update_period=4,
         clipping_update_period=100,
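
For context, a small illustration of what param_max_rms bounds, using the same RMS formula the optimizer computes in the hunks below; the tensor shape and scale here are made up:

    import torch

    # Illustration only: param_max_rms caps the root-mean-square value of each
    # non-scalar parameter tensor; this commit raises the default cap 2.0 -> 3.0.
    p = torch.randn(256, 256) * 2.5     # hypothetical parameter tensor, RMS ~ 2.5
    param_rms = (p ** 2).mean().sqrt()  # same formula the optimizer uses below
    print(param_rms.item() > 3.0)       # False; under the old 2.0 cap this tensor
                                        # would already count as "too large"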
@@ -272,7 +272,6 @@ class ScaledAdam(Optimizer):
         """
         lr = group["lr"]
         size_update_period = group["size_update_period"]
-        eps = group["eps"]
         beta1 = group["betas"][0]
 
         grad = p.grad
@@ -289,7 +288,7 @@ class ScaledAdam(Optimizer):
             scale_grads[step % size_update_period] = (p * grad).sum()
             if step % size_update_period == size_update_period - 1:
                 param_rms = state["param_rms"]
-                param_rms.copy_((p**2).mean().sqrt().clamp_(min=eps))
+                param_rms.copy_((p**2).mean().sqrt())
                 if step > 0:
                     # self._size_update() learns the overall scale on the
                     # parameter, by shrinking or expanding it.
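
The effect of dropping the clamp_ is easiest to see on a degenerate input. A minimal sketch (the zero tensor is contrived; eps matches the constructor default above):

    import torch

    eps = 1.0e-08
    p = torch.zeros(10, 10)  # contrived: a parameter that has collapsed to zero

    # Before this commit, the cached RMS was floored at eps as it was stored:
    old_cached = (p ** 2).mean().sqrt().clamp_(min=eps)  # tensor(1.0000e-08)

    # After it, the raw (possibly zero) RMS is cached, and the floor is applied
    # only at the point of use, via param_min_rms (see the final hunk below).
    new_cached = (p ** 2).mean().sqrt()  # tensor(0.)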
@@ -330,6 +329,7 @@ class ScaledAdam(Optimizer):
         size_lr = group["lr"] * group["scalar_lr_scale"]
         param_min_rms = group["param_min_rms"]
         param_max_rms = group["param_max_rms"]
+        eps = group["eps"]
         step = state["step"]
         size_update_period = scale_grads.shape[0]
@@ -348,7 +348,6 @@ class ScaledAdam(Optimizer):
         # we don't bother with bias_correction1; this will help prevent divergence
         # at the start of training.
 
-        eps = 1.0e-10  # this value is not critical.
         denom = scale_exp_avg_sq.sqrt() + eps
 
         scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
@@ -356,7 +355,9 @@ class ScaledAdam(Optimizer):
         is_too_small = (param_rms < param_min_rms)
         is_too_large = (param_rms > param_max_rms)
-        scale_step.masked_fill_(is_too_small, size_lr * size_update_period)
+        # when the param gets too small, just don't shrink it any further.
+        scale_step.masked_fill_(is_too_small, 0.0)
+        # when it gets too large, stop it from getting any larger.
         scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
 
         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
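
This is one half of the "change how param_min_rms is applied" part of the commit: a too-small parameter used to be actively grown back (the old code filled in a positive scale step), whereas now its scale is simply frozen. A toy reproduction with made-up values for size_lr, the thresholds, and the proposed steps:

    import torch

    size_lr = 0.025                       # e.g. group["lr"] * group["scalar_lr_scale"]
    size_update_period = 4
    param_min_rms, param_max_rms = 1.0e-05, 3.0

    param_rms = torch.tensor([1.0e-06, 0.5, 5.0])    # too small / OK / too large
    scale_step = torch.tensor([-0.01, -0.01, 0.02])  # whatever Adam proposed

    is_too_small = param_rms < param_min_rms
    is_too_large = param_rms > param_max_rms

    # New rule: freeze the scale of a too-small param (the old code filled in
    # +size_lr * size_update_period here, i.e. grew it back) ...
    scale_step.masked_fill_(is_too_small, 0.0)
    # ... while a too-large param is still forced to shrink by a fixed amount.
    scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
    print(scale_step)  # tensor([ 0.0000, -0.0100, -0.1000])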
@@ -383,8 +384,10 @@ class ScaledAdam(Optimizer):
         lr = group["lr"]
         beta1, beta2 = group["betas"]
         eps = group["eps"]
+        param_min_rms = group["param_min_rms"]
         step = state["step"]
 
         exp_avg_sq = state["exp_avg_sq"]
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
                                         value=(1-beta2))
@@ -399,8 +402,7 @@ class ScaledAdam(Optimizer):
         denom += eps
         grad = grad / denom
 
-        alpha = -lr * (1-beta1) * state["param_rms"]
+        alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms)
 
         delta = state["delta"]
         delta.add_(grad * alpha)
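
And the other half: param_min_rms now acts as a floor on the RMS that scales the step, so a parameter whose cached RMS has collapsed toward zero still receives a nonzero update instead of having alpha (and hence its update) go to zero. A sketch with made-up lr and beta1 and a degenerate zero-RMS parameter:

    import torch

    lr, beta1, param_min_rms = 0.05, 0.9, 1.0e-05
    param_rms = torch.tensor(0.0)  # degenerate case the clamp protects against

    alpha_old = -lr * (1 - beta1) * param_rms                           # 0.0 -> frozen
    alpha_new = -lr * (1 - beta1) * param_rms.clamp(min=param_min_rms)  # -5.0e-08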