Increase lr_update_period to 200. The update takes about 2 minutes, fore entire model.

This commit is contained in:
Daniel Povey 2022-07-09 11:36:54 +08:00
parent 61cab3ab65
commit 209acaf6e4

View File

@ -76,7 +76,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
param_min_rms=1.0e-05, param_min_rms=1.0e-05,
param_max_rms=2.0, param_max_rms=2.0,
size_update_period=4, size_update_period=4,
lr_update_period=20, lr_update_period=200,
grad_cov_period=3, grad_cov_period=3,
): ):
@ -205,9 +205,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs) state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
step = state["step"] step = state["step"]
delta = state["delta"] delta = state["delta"]
delta.mul_(beta1) delta.mul_(beta1)
numel = p.numel() numel = p.numel()