diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index e8c9f6c0e..62d30e65c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -68,6 +68,8 @@ class NeutralGradient(Optimizer):
           to the parameter rms) and 0.0 means no such normalization at all (only a scalar
           per tensor). In any case we fully normalize the parameter scale at the
           whole-tensor level, for non-scalar tensors.
+      grad_min_rand: Minimum proportion of random tensor that we add to the gradient covariance,
+          to make the gradient covariance have less effect.
       lr_for_speedup: A learning rate that determines how much speedup we aim for by
           accumulating gradients over multiple batches (the update is done after
           different delays for different parameters, based on an estimate of their
@@ -91,6 +93,7 @@ class NeutralGradient(Optimizer):
         estimate_period=2000,
         stats_steps=200,
         param_pow=1.0,
+        grad_min_rand=0.2,
         lr_for_speedup=0.03,
         speedup_recalibrate_period=100,
         speedup_first_step=500,
@@ -136,6 +139,7 @@ class NeutralGradient(Optimizer):
             estimate_period=estimate_period,
             stats_steps=stats_steps,
             param_pow=param_pow,
+            grad_min_rand=grad_min_rand,
             lr_for_speedup=lr_for_speedup,
             speedup_recalibrate_period=speedup_recalibrate_period,
             speedup_first_step=speedup_first_step,
@@ -172,6 +176,7 @@ class NeutralGradient(Optimizer):
             estimate_period = group["estimate_period"]
             stats_steps = group["stats_steps"]
             param_pow = group["param_pow"]
+            grad_min_rand = group["grad_min_rand"]
             lr_for_speedup = group["lr_for_speedup"]
             speedup_first_step = group["speedup_first_step"]
             speedup_recalibrate_period = group["speedup_recalibrate_period"]
@@ -349,7 +354,8 @@ class NeutralGradient(Optimizer):
                 # The full update.
                 step_within_period = state["step_within_period"]
                 if step_within_period == estimate_period:
-                    self._estimate_projections(p, state, param_eps, param_rel_eps, param_pow)
+                    self._estimate_projections(p, state, param_eps, param_rel_eps,
+                                               param_pow, grad_min_rand)
                     state["step_within_period"] = 0
                 elif step_within_period >= estimate_period - stats_steps:
                     self._store_grad_stats(grad, state, max_fullcov_size)
@@ -459,7 +465,9 @@ class NeutralGradient(Optimizer):
                               state: dict,
                               param_eps: float,
                               param_rel_eps: float,
-                              param_pow: float) -> None:
+                              param_pow: float,
+                              grad_min_rand: float,
+                              ) -> None:
         """
         Estimate the per-dimension projections proj_0, proj_1 and so on. These
         projections, when applied with forward==True by _change_coordinates(),
@@ -484,6 +492,8 @@ class NeutralGradient(Optimizer):
              parameter variance averaged over the entire tensor.
           param_pow: param_pow==1.0 means fully normalizing for parameter scale;
              setting to a smaller value closer to 0.0 means partial normalization.
+          grad_min_rand: minimum proportion of random tensor to mix with the
+             gradient covariance matrix.
         """
         kwargs = {'device':p.device, 'dtype':p.dtype}

@@ -503,6 +513,7 @@ class NeutralGradient(Optimizer):
                 del state[f"grad_cov_{dim}"]  # save memory

                 self._randomize_lowrank_cov(grad_cov, count, p.shape)
+                param_cov = self._get_param_cov(p, dim)

                 # P is the SPD matrix such that P G P^T == C^{param_pow},

@@ -669,7 +680,8 @@ class NeutralGradient(Optimizer):

         return param_cov

-    def _randomize_lowrank_cov(self, cov: Tensor, rank: int, shape: torch.Size) -> None:
+    def _randomize_lowrank_cov(self, cov: Tensor, rank: int, shape: torch.Size,
+                               min_rand: float = 0.0) -> None:
         """
         If this covariance has too-low rank (based on our knowledge of how we computed it),
         add to it a random positive-semidefinite matrix to make sure it is full-rank.
@@ -703,7 +715,7 @@ class NeutralGradient(Optimizer):
         # We are trying to find the trace of the matrix we need to add,
         # i.e. how big it is relative to the trace of the existing matrix.
         missing_rank = (required_rank - rank)
-        rand_scale = 0.2 * missing_rank / rank
+        rand_scale = max(min_rand, 0.2 * missing_rank / rank)
         R = torch.randn(size, size, device=cov.device, dtype=cov.dtype)
         R = torch.matmul(R, R.t())  # positive semidefinite random matrix
         R_scale = R.diag().mean() + 1.0e-20