introduce grad_cov_period

Daniel Povey 2022-07-09 10:29:23 +08:00
parent 35a51bc153
commit 61cab3ab65


@@ -77,6 +77,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         param_max_rms=2.0,
         size_update_period=4,
         lr_update_period=20,
+        grad_cov_period=3,
     ):
@@ -93,6 +94,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             param_max_rms=param_max_rms,
             size_update_period=size_update_period,
             lr_update_period=lr_update_period,
+            grad_cov_period=grad_cov_period,
         )
         super(PrAdam, self).__init__(params, defaults)
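For context: with the two hunks above, grad_cov_period becomes a constructor argument (default 3) that is stored per parameter group. A minimal usage sketch, assuming this file is importable as optim and using a placeholder module:

    import torch
    from optim import PrAdam  # hypothetical import path for this file

    model = torch.nn.Linear(256, 256)  # placeholder module
    # The new knob: accumulate the per-dim gradient covariance only on
    # every grad_cov_period-th step (plus a short warm-up), since it is
    # consumed only when the projections are rediagonalized.
    optimizer = PrAdam(model.parameters(), grad_cov_period=3)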
@@ -463,6 +465,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         lr = group["lr"]
         beta1, beta2 = group["betas"]
         eps = group["eps"]
+        grad_cov_period = group["grad_cov_period"]
+        step = state["step"]
         for dim in range(grad.ndim):
@@ -472,12 +476,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             size = grad.shape[dim]
             if size == 1:
                 continue
-            grad_cov = state[f"grad_cov_{dim}"]
-            this_m = p.transpose(-1, dim).reshape(-1, size)  # parameter matrix M
-            this_g = grad.transpose(-1, dim).reshape(-1, size)  # M_grad
-            # could perhaps accumulate grad_cov less frequently; it's only
-            # needed when we rediagonalize which is not that common.
-            grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))
+            if step % grad_cov_period == 0 or step < grad_cov_period:
+                grad_cov = state[f"grad_cov_{dim}"]
+                this_m = p.transpose(-1, dim).reshape(-1, size)  # parameter matrix M
+                this_g = grad.transpose(-1, dim).reshape(-1, size)  # M_grad
+                # could perhaps accumulate grad_cov less frequently; it's only
+                # needed when we rediagonalize which is not that common.
+                grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))
             grad = self._project(grad, state, forward=True)
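Outside the diff, the new gating reads as follows; a standalone sketch of the accumulation step only, with names mirroring the snippet above (not the full PrAdam update logic):

    import torch

    def maybe_accumulate_grad_cov(state: dict, grad: torch.Tensor, dim: int,
                                  beta2: float, grad_cov_period: int) -> None:
        # Accumulate the size-x-size gradient covariance for one tensor
        # dim, but only on steps where the refreshed value can later
        # matter for rediagonalization.
        step = state["step"]
        size = grad.shape[dim]
        # Warm up on the first grad_cov_period steps so grad_cov is
        # populated early, then refresh every grad_cov_period-th step.
        if step % grad_cov_period == 0 or step < grad_cov_period:
            grad_cov = state[f"grad_cov_{dim}"]
            this_g = grad.transpose(-1, dim).reshape(-1, size)
            grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))

With the default of 3, steps 0, 1 and 2 all accumulate, then steps 3, 6, 9, ... do. Note that mul_(beta2) now also runs only on those steps, so grad_cov decays per accumulation rather than per optimizer step; presumably acceptable, since it only feeds the occasional rediagonalization.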