Bug fix
This commit is contained in:
parent
b7035844a2
commit
d9a6180ae0
@ -123,6 +123,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
lr = group["lr"]
|
||||
size_lr = lr * group["size_lr_scale"]
|
||||
beta1, beta2 = group["betas"]
|
||||
scalar_max = group["scalar_max"]
|
||||
eps = group["eps"]
|
||||
size_update_period = group["size_update_period"]
|
||||
param_min_rms = group["param_min_rms"]
|
||||
@ -231,7 +232,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
# For parameters with very few elements we just use a form
|
||||
# of Adam with a scale factor to reflect the overall
|
||||
# parameter rms. Updates delta.
|
||||
self._step_scalar(beta1, beta2, eps, lr, p, grad, state)
|
||||
self._step_scalar(scalar_max, beta1, beta2, eps, lr, p, grad, state)
|
||||
else:
|
||||
if step % lr_update_period == 0 and step > 0:
|
||||
self._accum_param_covs(group, p, state)
|
||||
@ -595,6 +596,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
|
||||
|
||||
def _step_scalar(self,
|
||||
scalar_max: float,
|
||||
beta1: float,
|
||||
beta2: float,
|
||||
eps: float,
|
||||
@ -618,7 +620,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
|
||||
delta = state["delta"]
|
||||
delta.add_(grad / denom, alpha=alpha)
|
||||
scalar_max = state["scalar_max"]
|
||||
p.clamp_(min=-scalar_max, max=scalar_max)
|
||||
|
||||
def _project(self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user