Use scalar_lr_scale for scalars as well as sizes.

Daniel Povey 2022-09-18 14:12:27 +08:00
parent 76031a7c1d
commit 69404f61ef


@@ -50,9 +50,10 @@ class ScaledAdam(Optimizer):
by this quantity.
betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
Must satisfy 0 < beta1 <= beta2 < 1.
- size_lr_scale: A scaling factor on the learning rate, that we use to update the
-     scale of each parameter tensor. If each parameter were decomposed
-     as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, size_lr_scale
+ scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
+     scale of each parameter tensor and scalar parameters of the model.
+     If each parameter were decomposed
+     as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
would be the scaling factor on the learning rate of p_scale.
eps: A general-purpose epsilon to prevent division by zero
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
@@ -74,7 +75,7 @@ class ScaledAdam(Optimizer):
lr=3e-02,
clipping_scale=None,
betas=(0.9, 0.98),
- size_lr_scale=0.1,
+ scalar_lr_scale=0.1,
eps=1.0e-08,
param_min_rms=1.0e-05,
param_max_rms=2.0,
@@ -89,7 +90,7 @@ class ScaledAdam(Optimizer):
lr=lr,
clipping_scale=clipping_scale,
betas=betas,
- size_lr_scale=size_lr_scale,
+ scalar_lr_scale=scalar_lr_scale,
eps=eps,
param_min_rms=param_min_rms,
param_max_rms=param_max_rms,
@@ -326,7 +327,7 @@ class ScaledAdam(Optimizer):
param_rms = state["param_rms"]
beta1, beta2 = group["betas"]
- size_lr = group["lr"] * group["size_lr_scale"]
+ size_lr = group["lr"] * group["scalar_lr_scale"]
param_min_rms = group["param_min_rms"]
param_max_rms = group["param_max_rms"]
step = state["step"]
@@ -417,7 +418,7 @@ class ScaledAdam(Optimizer):
beta1, beta2 = group["betas"]
scalar_max = group["scalar_max"]
eps = group["eps"]
lr = group["lr"]
lr = group["lr"] * group["scalar_lr_scale"]
grad = p.grad
exp_avg_sq = state["exp_avg_sq"] # scalar
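
For reference, a minimal sketch of what the renamed hyperparameter controls after this commit: the single scalar_lr_scale factor now scales the learning rate both for the per-tensor scale updates (size_lr = lr * scalar_lr_scale) and for scalar parameters of the model (lr = lr * scalar_lr_scale). The toy model, the import path optim, and the exact ScaledAdam constructor call below are illustrative assumptions based only on the defaults visible in this diff, not part of the commit.

# Sketch only: assumes the patched ScaledAdam (e.g. in optim.py) is importable.
import torch
from optim import ScaledAdam  # hypothetical import path for the patched file


class ToyModel(torch.nn.Module):
    """Toy model with one weight matrix and one scalar parameter."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)
        # A one-element ("scalar") parameter: after this commit its learning
        # rate is also multiplied by scalar_lr_scale, not only the per-tensor
        # scale factors.
        self.log_scale = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.linear(x) * self.log_scale.exp()


model = ToyModel()

# Constructor arguments copied from the defaults shown in this diff.
optimizer = ScaledAdam(
    model.parameters(),
    lr=3e-02,
    clipping_scale=None,
    betas=(0.9, 0.98),
    scalar_lr_scale=0.1,  # renamed from size_lr_scale in this commit
    eps=1.0e-08,
    param_min_rms=1.0e-05,
    param_max_rms=2.0,
)

# Effective learning rates implied by the two changed lines in the step code:
#   per-tensor scale update: size_lr = lr * scalar_lr_scale = 3e-02 * 0.1 = 3e-03
#   scalar parameters:       lr      = lr * scalar_lr_scale = 3e-02 * 0.1 = 3e-03
loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()
optimizer.step()

If the actual constructor at this commit takes additional keyword arguments, only the scalar_lr_scale line above is the point being illustrated.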