Use scalar_lr_scale for scalars as well as sizes.

Daniel Povey 2022-09-18 14:12:27 +08:00
parent 76031a7c1d
commit 69404f61ef


@@ -50,9 +50,10 @@ class ScaledAdam(Optimizer):
            by this quantity.
     betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
            Must satisfy 0 < beta <= beta2 < 1.
-    size_lr_scale: A scaling factor on the learning rate, that we use to update the
-           scale of each parameter tensor.  If each parameter were decomposed
-           as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, size_lr_scale
+    scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
+           scale of each parameter tensor and scalar parameters of the model.
+           If each parameter were decomposed
+           as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
            would be the scaling factor on the learning rate of p_scale.
     eps: A general-purpose epsilon to prevent division by zero
     param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
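For intuition, here is a rough sketch of the decomposition the docstring refers to (illustrative only; the variable names are made up and this is not code from the commit):

import torch

# Hypothetical parameter tensor p, viewed as p = p_unit * p_scale.exp(),
# where p_unit has unit root-mean-square value.
p = torch.randn(512, 512) * 0.05
param_rms = (p ** 2).mean().sqrt()   # RMS of p
p_scale = param_rms.log()            # log of the scale factor
p_unit = p / param_rms               # (p_unit ** 2).mean().sqrt() == 1.0, up to float error

# scalar_lr_scale multiplies the learning rate used for updating p_scale (and, after
# this commit, for scalar parameters of the model), not the learning rate for p_unit.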
@@ -74,7 +75,7 @@ class ScaledAdam(Optimizer):
         lr=3e-02,
         clipping_scale=None,
         betas=(0.9, 0.98),
-        size_lr_scale=0.1,
+        scalar_lr_scale=0.1,
         eps=1.0e-08,
         param_min_rms=1.0e-05,
         param_max_rms=2.0,
@@ -89,7 +90,7 @@ class ScaledAdam(Optimizer):
             lr=lr,
             clipping_scale=clipping_scale,
             betas=betas,
-            size_lr_scale=size_lr_scale,
+            scalar_lr_scale=scalar_lr_scale,
             eps=eps,
             param_min_rms=param_min_rms,
             param_max_rms=param_max_rms,
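Callers that previously passed size_lr_scale now pass scalar_lr_scale. A minimal usage sketch, assuming a toy model and the defaults shown in the hunks above (the import path and exact signature may differ in your checkout):

import torch.nn as nn
from optim import ScaledAdam  # adjust the import to wherever ScaledAdam lives

model = nn.Linear(80, 512)
optimizer = ScaledAdam(
    model.parameters(),
    lr=3e-02,
    betas=(0.9, 0.98),
    scalar_lr_scale=0.1,  # formerly size_lr_scale
)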
@@ -326,7 +327,7 @@ class ScaledAdam(Optimizer):
         param_rms = state["param_rms"]
         beta1, beta2 = group["betas"]
-        size_lr = group["lr"] * group["size_lr_scale"]
+        size_lr = group["lr"] * group["scalar_lr_scale"]
         param_min_rms = group["param_min_rms"]
         param_max_rms = group["param_max_rms"]
         step = state["step"]
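With the defaults in this file (lr=3e-02, scalar_lr_scale=0.1), the step size used for the per-tensor scale update therefore works out to:

lr = 3e-02
scalar_lr_scale = 0.1
size_lr = lr * scalar_lr_scale  # 3e-03: learning rate for updating each tensor's scale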
@@ -417,7 +418,7 @@ class ScaledAdam(Optimizer):
         beta1, beta2 = group["betas"]
         scalar_max = group["scalar_max"]
         eps = group["eps"]
-        lr = group["lr"]
+        lr = group["lr"] * group["scalar_lr_scale"]
         grad = p.grad
         exp_avg_sq = state["exp_avg_sq"]  # scalar
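The scalar-parameter branch now applies the same scaling to its step size. A simplified sketch of an Adam-style update for a scalar parameter under the scaled learning rate (this omits momentum and bias correction, so it is not the exact code in the file):

import torch

p = torch.tensor(0.5)            # hypothetical scalar parameter
grad = torch.tensor(0.02)        # its gradient
exp_avg_sq = torch.tensor(1e-4)  # running average of squared gradients

lr, scalar_lr_scale, beta2, eps = 3e-02, 0.1, 0.98, 1.0e-08
lr = lr * scalar_lr_scale        # as in the hunk above

exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad ** 2
p = p - lr * grad / (exp_avg_sq.sqrt() + eps)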