Make LR-matrix updates less frequent later in training; fix bug that made param_cov stats too fresh (the freshness weight used lr_update_period instead of param_cov_period)

Daniel Povey 2022-07-15 07:59:30 +08:00
parent 689441b237
commit b6ee698278


@@ -127,7 +127,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 scalar_max: Maximum absolute value for scalar parameters
 size_update_period: The periodicity, in steps, with which we update the size (scale)
               of the parameter tensor.  This is provided to save a little time.
-lr_update_period: The periodicity, in steps, with which we update the learning-rate matrices.
+lr_update_period: Determines the periodicity, in steps, with which we update the
+              learning-rate matrices.  The first number is the periodicity at
+              the start of training; the second number is the periodicity
+              later in training.  One update of the learning-rate matrices
+              can take as long as 50 or more minibatches, because SVD on GPU is slow.
               ** This is important for the speed/optimization tradeoff. **
 param_cov_period: The periodicity, in steps, with which we update the parameter covariance
               stats.
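To make the new tuple semantics concrete, here is a minimal sketch of the schedule it implies, using the interpolation formula that `_is_lr_update_step` (added below in this commit) applies; the helper name is illustrative, not part of the diff:

```python
def current_lr_update_period(step: int,
                             lr_update_period=(200, 2000)) -> float:
    # Interpolates from the initial period toward the final period;
    # approaches period_final once step >> 4 * period_final.
    period_initial, period_final = lr_update_period
    return (period_initial +
            (period_final - period_initial) * step /
            (step + 4 * period_final))

for step in (0, 2000, 8000, 40000):
    print(step, round(current_lr_update_period(step)))
# 0 -> 200, 2000 -> 560, 8000 -> 1100, 40000 -> 1700 (approaching 2000)
```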
@@ -148,7 +152,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             param_max_rms=2.0,
             scalar_max=2.0,
             size_update_period=4,
-            lr_update_period=200,
+            lr_update_period=(200, 2000),
             grad_cov_period=3,
             param_cov_period=100,
             max_block_size=1024,
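For reference, these defaults would be overridden when constructing the optimizer; a hypothetical usage sketch (the class name `PrAdam` and the import path are assumptions for illustration, not taken from this diff):

```python
import torch
from optim import PrAdam  # hypothetical import path / class name

model = torch.nn.Linear(512, 512)
optimizer = PrAdam(
    model.parameters(),
    lr_update_period=(200, 2000),  # (initial, final) update period in steps
    param_cov_period=100,          # how often to update parameter-covariance stats
    grad_cov_period=3,             # how often to update gradient-covariance stats
)
```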
@@ -350,7 +354,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         """
         lr = group["lr"]
         size_update_period = group["size_update_period"]
-        lr_update_period = group["lr_update_period"]
         grad_cov_period = group["grad_cov_period"]
         param_cov_period = group["param_cov_period"]
         eps = group["eps"]
@@ -386,7 +389,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 else:
                     if step % param_cov_period == 0:
                         self._update_param_cov(group, p, state)
-                    if step % lr_update_period == 0 and step > 0 and "zero_step" in state:
+                    if self._is_lr_update_step(group, state):
                         self._update_lrs(group, p, state)
                         self._zero_exp_avg_sq(state)
                     if step % grad_cov_period == 0:
@@ -463,10 +466,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         (except batch and trivial and rank-1 dims)
         """
         eps = group["eps"]
-        lr_update_period = group["lr_update_period"]
+        param_cov_period = group["param_cov_period"]
         step = state["step"]
-        this_weight = (group["param_cov_freshness"] * lr_update_period /
-                       (step + lr_update_period))
+        this_weight = (group["param_cov_freshness"] * param_cov_period /
+                       (step + param_cov_period))
         batch_size = p.shape[0]
         numel = p.numel() // batch_size
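The fix above changes which period enters the freshness weight: the stats are accumulated every param_cov_period (100) steps, but the old code weighted them using lr_update_period (200), so for large step counts the new stats were weighted roughly twice as heavily as intended. (After this commit lr_update_period is also a pair, so the old expression would fail outright: a float cannot be multiplied by a tuple.) A standalone sketch of the two weights; the default value of param_cov_freshness is not shown in this diff, so 1.0 is assumed here:

```python
def freshness_weight(step: int, period: int,
                     param_cov_freshness: float = 1.0) -> float:
    # Weight given to the newly-accumulated covariance stats at this step;
    # matches this_weight in _update_param_cov.
    return param_cov_freshness * period / (step + period)

for step in (100, 1000, 10000):
    old = freshness_weight(step, period=200)  # bug: used lr_update_period
    new = freshness_weight(step, period=100)  # fix: uses param_cov_period
    print(f"step={step}: old={old:.4f} new={new:.4f}")
# step=100: old=0.6667 new=0.5000
# step=1000: old=0.1667 new=0.0909
# step=10000: old=0.0196 new=0.0099
```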
@@ -507,6 +511,30 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         state["exp_avg_sq"].zero_()
         state["zero_step"] = state["step"]
+
+    def _is_lr_update_step(self,
+                           group: dict,
+                           state: dict) -> bool:
+        """
+        Returns True if the learning-rate matrices for this tensor need to be
+        updated on this step, and False if not.  The period with which we
+        update them increases from (by default) 200 at the start of training
+        to 2000 later on.
+        """
+        try:
+            zero_step = state["zero_step"]
+        except KeyError:
+            # This parameter tensor has no learning-rate matrices to estimate.
+            return False
+        step = state["step"]
+        period_initial, period_final = group["lr_update_period"]
+        # This formula gradually increases the update period from period_initial
+        # at the start of training to period_final when step >> 4 * period_final.
+        cur_update_period = (period_initial +
+                             ((period_final - period_initial) * step /
+                              (step + 4 * period_final)))
+        return (step >= zero_step + cur_update_period)
+
     def _update_lrs(self,
                     group: dict,
                     p: Tensor,
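Because `_zero_exp_avg_sq` (shown above) records `state["zero_step"]` whenever the learning-rate matrices are updated, `_is_lr_update_step` acts as a timer whose interval grows over training, rather than a fixed modulo test. A standalone simulation of which steps trigger updates under the defaults; it assumes zero_step starts at 0 when the state is initialized, which this diff does not show:

```python
def simulate_lr_update_steps(num_steps: int,
                             lr_update_period=(200, 2000)) -> list:
    # Mimics _is_lr_update_step plus the zero_step reset done via
    # _zero_exp_avg_sq after each update of the learning-rate matrices.
    period_initial, period_final = lr_update_period
    zero_step = 0  # assumption: state starts with zero_step = 0
    update_steps = []
    for step in range(1, num_steps):
        cur_update_period = (period_initial +
                             (period_final - period_initial) * step /
                             (step + 4 * period_final))
        if step >= zero_step + cur_update_period:
            update_steps.append(step)
            zero_step = step
    return update_steps

print(simulate_lr_update_steps(3000))
# [256, 578, 974, 1451, 2013, 2663]: the gap between successive updates
# grows from roughly 250 toward 2000 as training proceeds.
```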