From 11eac9089eed742a97596397c4a947160c8ea8ad Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 20 May 2022 23:05:05 +0800 Subject: [PATCH] Fix w.r.t. how cov stats are estimated: storing more stats, and now always zeroing the stats. --- .../ASR/pruned_transducer_stateless4/optim.py | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py index 74d6bf569..9db522b31 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py @@ -666,20 +666,17 @@ class Cain(Optimizer): ndim = grad.ndim numel = grad.numel() step = state["step"] - for i in range(grad.ndim): - - dim = (i + step) % grad.ndim - + for dim in range(grad.ndim): # see for each dim in turn whether we want to perform any changes in co-ordinates, # or store any stats. size = grad.shape[dim] if size <= 3 or size % 2 == 1 or size >= 2048 or size == numel: - # we don't do any such co-ordinate changes in dims that are too small (no point) - # or large (too slow), or convolutional (odd). can revisit this later. + # we don't do any such co-ordinate changes in dims with sizes + # that are too small (no point) or large (too slow), or that are + # assumed convolutional (because they are odd). We can revisit + # this later. continue - allow_store_stats = (i == grad.ndim - 1) - grad = self._change_coordinates_for_dim(grad, state, dim, forward, - allow_store_stats) + grad = self._change_coordinates_for_dim(grad, state, dim, forward) return grad @@ -688,7 +685,6 @@ class Cain(Optimizer): state: dict, dim: int, forward: bool, - allow_store_stats: bool, orig_dim: int = -1): assert grad.ndim > 1 if not (grad.ndim == 2 and dim == 1): @@ -710,7 +706,7 @@ class Cain(Optimizer): new_shape = new_grad.shape new_grad = new_grad.reshape(-1, new_grad.shape[-1]) new_grad = self._change_coordinates_for_dim(new_grad, state, 1, forward, - allow_store_stats, orig_dim=dim) + orig_dim=dim) return new_grad.reshape(new_shape).permute(*rev_dims_order) # OK: grad.ndim == 2 and dim == 1 @@ -719,8 +715,8 @@ class Cain(Optimizer): orig_dim = dim step = state["step"] must_store_stats, must_zero_stats = self._must_store_stats(step) - if must_store_stats and allow_store_stats and forward: - # store stats for 100 iters preceding when we estimate the + if must_store_stats and forward: + # store stats for 200 iters preceding when we estimate the # transform. stats_name = f"cov_{orig_dim}" if not stats_name in state: @@ -728,7 +724,6 @@ class Cain(Optimizer): device=grad.device) cov = state[stats_name] if must_zero_stats: - #print("zero") cov.zero_() cov += torch.matmul(grad.t(), grad) * (1/grad.shape[0])