This commit is contained in:
Daniel Povey 2022-07-10 10:20:39 +08:00
parent b7035844a2
commit d9a6180ae0

View File

@ -123,6 +123,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
lr = group["lr"]
size_lr = lr * group["size_lr_scale"]
beta1, beta2 = group["betas"]
scalar_max = group["scalar_max"]
eps = group["eps"]
size_update_period = group["size_update_period"]
param_min_rms = group["param_min_rms"]
@ -231,7 +232,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# For parameters with very few elements we just use a form
# of Adam with a scale factor to reflect the overall
# parameter rms. Updates delta.
self._step_scalar(beta1, beta2, eps, lr, p, grad, state)
self._step_scalar(scalar_max, beta1, beta2, eps, lr, p, grad, state)
else:
if step % lr_update_period == 0 and step > 0:
self._accum_param_covs(group, p, state)
@ -595,6 +596,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
def _step_scalar(self,
scalar_max: float,
beta1: float,
beta2: float,
eps: float,
@ -618,7 +620,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
delta = state["delta"]
delta.add_(grad / denom, alpha=alpha)
scalar_max = state["scalar_max"]
p.clamp_(min=-scalar_max, max=scalar_max)
def _project(self,