diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 58e9b3456..84a2d7811 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -127,7 +127,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of scalar_max: Maximum absolute value for scalar parameters size_update_period: The periodicity, in steps, with which we update the size (scale) of the parameter tensor. This is provided to save a little time. - lr_update_period: The periodicity, in steps, with which we update the learning-rate matrices. + lr_update_period: Determines the periodicity, in steps, with which we update the + learning-rate matrices. The first number is the periodicity at + the start of training, the second number is the periodicity + later in training. One step of updating the learning-rate matrices + can take as long as processing more than 50 minibatches, because SVD on GPU is slow. ** This is important for the speed/optimizaton tradeoff. ** param_cov_period: The periodicity, in steps, with which we update the parameter covariance stats. 
@@ -148,7 +152,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of param_max_rms=2.0, scalar_max=2.0, size_update_period=4, - lr_update_period=200, + lr_update_period=(200, 2000), grad_cov_period=3, param_cov_period=100, max_block_size=1024, @@ -350,7 +354,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of """ lr = group["lr"] size_update_period = group["size_update_period"] - lr_update_period = group["lr_update_period"] grad_cov_period = group["grad_cov_period"] param_cov_period = group["param_cov_period"] eps = group["eps"] @@ -386,7 +389,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of else: if step % param_cov_period == 0: self._update_param_cov(group, p, state) - if step % lr_update_period == 0 and step > 0 and "zero_step" in state: + if self._is_lr_update_step(group, state): self._update_lrs(group, p, state) self._zero_exp_avg_sq(state) if step % grad_cov_period == 0: @@ -463,10 +466,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of (except batch and trivial and rank-1 dims) """ eps = group["eps"] - lr_update_period = group["lr_update_period"] + param_cov_period = group["param_cov_period"] step = state["step"] - this_weight = (group["param_cov_freshness"] * lr_update_period / - (step + lr_update_period)) + + this_weight = (group["param_cov_freshness"] * param_cov_period / + (step + param_cov_period)) batch_size = p.shape[0] numel = p.numel() // batch_size @@ -507,6 +511,30 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of state["exp_avg_sq"].zero_() state["zero_step"] = state["step"] + def _is_lr_update_step(self, + group: dict, + state: dict) -> bool: + """ + Returns True if on this step we need to update the learning-rate matrices for this tensor + and False if not. The periodicity with which we update them increases from + (by default) 200 at the start of training to 2000 later on. 
+ """ + try: + zero_step = state["zero_step"] + except KeyError: + # No "zero_step" entry: this parameter tensor has no learning-rate matrices to estimate + return False + step = state["step"] + period_initial, period_final = group["lr_update_period"] + # this formula gradually increases the periodicity from period_initial at the start + # to period_final when step >> 4 * period_final + cur_update_period = (period_initial + + ((period_final - period_initial) * step / + (step + 4 * period_final))) + return (step >= zero_step + cur_update_period) + + + def _update_lrs(self, group: dict, p: Tensor,