diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 6143a6fe3..61be0ddb7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -77,6 +77,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         param_max_rms=2.0,
         size_update_period=4,
         lr_update_period=20,
+        grad_cov_period=3,
     ):
@@ -93,6 +94,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             param_max_rms=param_max_rms,
             size_update_period=size_update_period,
             lr_update_period=lr_update_period,
+            grad_cov_period=grad_cov_period,
         )
         super(PrAdam, self).__init__(params, defaults)
@@ -463,6 +465,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         lr = group["lr"]
         beta1, beta2 = group["betas"]
         eps = group["eps"]
+        grad_cov_period = group["grad_cov_period"]
+
         step = state["step"]

         for dim in range(grad.ndim):
@@ -472,12 +476,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             size = grad.shape[dim]
             if size == 1:
                 continue
-            grad_cov = state[f"grad_cov_{dim}"]
-            this_m = p.transpose(-1, dim).reshape(-1, size)  # parameter matrix M
-            this_g = grad.transpose(-1, dim).reshape(-1, size)  # M_grad
-            # could perhaps accumulate grad_cov less frequently; it's only
-            # needed when we rediagonalize which is not that common.
-            grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))
+            if step % grad_cov_period == 0 or step < grad_cov_period:
+                grad_cov = state[f"grad_cov_{dim}"]
+                this_m = p.transpose(-1, dim).reshape(-1, size)  # parameter matrix M
+                this_g = grad.transpose(-1, dim).reshape(-1, size)  # M_grad
+                # could perhaps accumulate grad_cov less frequently; it's only
+                # needed when we rediagonalize which is not that common.
+                grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))

             grad = self._project(grad, state, forward=True)
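
The change above makes the gradient-covariance accumulation periodic: grad_cov is decayed and updated only on steps where step % grad_cov_period == 0, plus the first grad_cov_period warm-up steps, since the covariance is only consumed when the optimizer rediagonalizes, which is comparatively rare. Below is a minimal standalone sketch of that pattern; beta2, grad_cov_period, and the accumulation line mirror the diff, while the toy loop, tensor shapes, and random data are assumptions standing in for the optimizer's real state and gradients.

import torch

beta2 = 0.98          # decay factor; stand-in for group["betas"][1]
grad_cov_period = 3   # the new default added by this diff
size = 8              # size of the tensor dimension being accumulated over

grad_cov = torch.zeros(size, size)
for step in range(10):
    this_g = torch.randn(16, size)  # stand-in for the reshaped gradient M_grad
    # Accumulate on each of the first grad_cov_period steps so the estimate
    # warms up, then only once every grad_cov_period steps thereafter.
    if step % grad_cov_period == 0 or step < grad_cov_period:
        grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))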