Increase lr_update_period to 200. The update takes about 2 minutes, for the entire model.

This commit is contained in:
Daniel Povey 2022-07-09 11:36:54 +08:00
parent 61cab3ab65
commit 209acaf6e4

View File

@ -76,7 +76,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
param_min_rms=1.0e-05,
param_max_rms=2.0,
size_update_period=4,
lr_update_period=20,
lr_update_period=200,
grad_cov_period=3,
):
@ -190,7 +190,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# param_cov_{dim} is the averaged-over-time gradient of parameters on this dimension, treating
# all other dims as a batch axis.
state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)
state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)
# grad_cov_{dim} is the covariance of gradients on this axis (without
# any co-ordinate changes), treating all other axes as a batch axis.
@ -205,9 +205,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
step = state["step"]
delta = state["delta"]
delta.mul_(beta1)
numel = p.numel()