mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
introduce grad_cov_period
This commit is contained in:
parent
35a51bc153
commit
61cab3ab65
@ -77,6 +77,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
param_max_rms=2.0,
|
||||
size_update_period=4,
|
||||
lr_update_period=20,
|
||||
grad_cov_period=3,
|
||||
):
|
||||
|
||||
|
||||
@ -93,6 +94,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
param_max_rms=param_max_rms,
|
||||
size_update_period=size_update_period,
|
||||
lr_update_period=lr_update_period,
|
||||
grad_cov_period=grad_cov_period,
|
||||
)
|
||||
|
||||
super(PrAdam, self).__init__(params, defaults)
|
||||
@ -463,6 +465,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
lr = group["lr"]
|
||||
beta1, beta2 = group["betas"]
|
||||
eps = group["eps"]
|
||||
grad_cov_period = group["grad_cov_period"]
|
||||
step = state["step"]
|
||||
|
||||
|
||||
for dim in range(grad.ndim):
|
||||
@ -472,12 +476,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
size = grad.shape[dim]
|
||||
if size == 1:
|
||||
continue
|
||||
grad_cov = state[f"grad_cov_{dim}"]
|
||||
this_m = p.transpose(-1, dim).reshape(-1, size) # parameter matrix M
|
||||
this_g = grad.transpose(-1, dim).reshape(-1, size) # M_grad
|
||||
# could perhaps accumulate grad_cov less frequently; it's only
|
||||
# needed when we rediagonalize which is not that common.
|
||||
grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))
|
||||
if step % grad_cov_period == 0 or step < grad_cov_period:
|
||||
grad_cov = state[f"grad_cov_{dim}"]
|
||||
this_m = p.transpose(-1, dim).reshape(-1, size) # parameter matrix M
|
||||
this_g = grad.transpose(-1, dim).reshape(-1, size) # M_grad
|
||||
# grad_cov is accumulated only when step % grad_cov_period == 0 (or while
|
||||
# step < grad_cov_period); it's only needed when we rediagonalize, which is not that common.
|
||||
grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))
|
||||
|
||||
grad = self._project(grad, state, forward=True)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user