Daniel Povey 2022-07-08 04:44:21 +08:00
parent 04d2e10b4f
commit ad2e698fc3


@@ -54,9 +54,6 @@ class LearnedGradient(Optimizer):
diagonalize_period: How often we re-diagonalize the gradient covariance; this
gets multiplied by lr_est_period, i.e. we diagonalize less frequently
than we update the learning-rate matrices.
max_step_scale: Value that determines the limit on how large the steps can be per element;
intended to prevent instability during early steps, before the learning-rate
matrices are well learned.
"""
def __init__(
self,
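The docstring above ties the two periods together: the learning-rate matrices are re-estimated every lr_est_period steps, and the gradient covariance is re-diagonalized only every lr_est_period * diagonalize_period steps. A minimal sketch of that schedule, using the defaults set in this commit; the loop and print statements are illustrative only, not part of the optimizer:

lr_est_period = 5          # new default in this commit (was 10)
diagonalize_period = 4     # multiplied by lr_est_period, i.e. every 20 steps

for step in range(1, 41):
    if step % lr_est_period == 0:
        print(f"step {step}: re-estimate the learning-rate matrices")
    if step % (lr_est_period * diagonalize_period) == 0:
        print(f"step {step}: re-diagonalize the gradient covariance")
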
@@ -71,9 +68,8 @@ class LearnedGradient(Optimizer):
param_max_rms=2.0,
lr_mat_min=0.01,
lr_mat_max=4.0,
lr_est_period=10,
lr_est_period=5,
diagonalize_period=4,
max_step_scale=2.0
):
@@ -90,7 +86,6 @@ class LearnedGradient(Optimizer):
lr_mat_max=lr_mat_max,
lr_est_period=lr_est_period,
diagonalize_period=diagonalize_period,
max_step_scale=max_step_scale,
)
super(LearnedGradient, self).__init__(params, defaults)
@@ -125,7 +120,6 @@ class LearnedGradient(Optimizer):
param_max_rms = group["param_max_rms"]
lr_est_period = group["lr_est_period"]
diagonalize_period = group["diagonalize_period"]
max_step_scale = group["max_step_scale"]
for p in group["params"]:
if p.grad is None:
@@ -232,6 +226,8 @@ class LearnedGradient(Optimizer):
step = state["step"]
delta = state["delta"]
delta.mul_(beta1)
numel = p.numel()
if numel > 1:
# Update the size/scale of p, and set param_rms
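For context on the "set param_rms" step: a hypothetical sketch of how a parameter's RMS could be measured and clamped to [param_min_rms, param_max_rms] before it is used to scale updates. The tensor and the minimum value below are assumed for illustration; this is not the optimizer's actual code:

import torch

p = torch.randn(256, 512)                     # example parameter tensor
param_min_rms, param_max_rms = 1.0e-05, 2.0   # max matches the default above; min assumed

# Root-mean-square of the parameter, clamped to the allowed range.
param_rms = (p ** 2).mean().sqrt()
param_rms = param_rms.clamp(min=param_min_rms, max=param_max_rms)
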
@@ -248,8 +244,7 @@ class LearnedGradient(Optimizer):
beta1, beta2, step, size_lr,
param_min_rms, param_max_rms)
delta = state["delta"]
delta.mul_(beta1)
if numel == 1:
# For parameters with very few elements we just use a form
# of Adam with a scale factor to reflect the overall
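The comment above describes falling back to an Adam-style update for parameters with only a few elements. A rough, self-contained sketch of such a single-element Adam step follows; the hyperparameter values are assumptions, and the scale factor the comment mentions is omitted, so this is illustration only, not the branch's actual implementation:

import torch

p = torch.tensor([0.5])               # a single-element parameter
grad = torch.tensor([0.1])
exp_avg = torch.zeros_like(p)         # first-moment (momentum) accumulator
exp_avg_sq = torch.zeros_like(p)      # second-moment accumulator
beta1, beta2, lr, eps, step = 0.9, 0.98, 3.0e-02, 1.0e-08, 1   # assumed values

exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = (exp_avg_sq / (1 - beta2 ** step)).sqrt().add_(eps)
p.add_(exp_avg / (1 - beta1 ** step) / denom, alpha=-lr)
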
@@ -482,8 +477,6 @@ class LearnedGradient(Optimizer):
M_delta = torch.matmul(M, proj_delta) # (eq.5), M_delta = M proj_delta
alpha = -meta_lr * (1-beta1)
M_delta_full = M_delta.reshape(M_full.shape)
this_delta = M_delta_full.transpose(dim, -1)
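
The matmul / reshape / transpose sequence above is the usual bookkeeping for applying a square matrix along one dimension of a tensor. A generic, shape-only illustration follows; the tensor sizes, the choice of dim, and the random M are made up for the example, so this is not the optimizer's exact code:

import torch

delta = torch.randn(4, 6, 8)          # stand-in for the update being transformed
dim = 1
size = delta.shape[dim]
M = torch.randn(size, size)           # stand-in for a learned learning-rate matrix

d = delta.transpose(dim, -1)          # move `dim` to the last position: (4, 8, 6)
shape = d.shape
d = d.reshape(-1, size)               # flatten the remaining dims: (32, 6)
d = torch.matmul(d, M.t())            # apply M along that dimension
this_delta = d.reshape(shape).transpose(dim, -1)   # restore the layout: (4, 6, 8)
assert this_delta.shape == delta.shape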