diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py index 936baf658..9c4646009 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py @@ -598,19 +598,21 @@ class Cain(Optimizer): """ ndim = grad.ndim numel = grad.numel() - for dim in range(grad.ndim): + step = state["step"] + for i in range(grad.ndim): + + dim = (i + step) % grad.ndim + # see for each dim in turn whether we want to perform any changes in co-ordinates, # or store any stats. size = grad.shape[dim] - if size <= 3 or size > 1024 or size == numel: + if size <= 3 or size % 2 == 1 or size >= 2048 or size == numel: # we don't do any such co-ordinate changes in dims that are too small (no point) - # or large (too slow) + # or large (too slow), or convolutional (odd). can revisit this later. continue - if dim == 0: - # FOR NOW: don't do such co-ordinate changes on output - # dims, which will generally be dimension zero. We can revisit this later. - continue - grad = self._change_coordinates_for_dim(grad, state, dim, forward) + allow_store_stats = (i == grad.ndim - 1) + grad = self._change_coordinates_for_dim(grad, state, dim, forward, + allow_store_stats) return grad @@ -619,6 +621,7 @@ class Cain(Optimizer): state: dict, dim: int, forward: bool, + allow_store_stats: bool, orig_dim: int = -1): assert grad.ndim > 1 if not (grad.ndim == 2 and dim == 1): @@ -640,7 +643,7 @@ class Cain(Optimizer): new_shape = new_grad.shape new_grad = new_grad.reshape(-1, new_grad.shape[-1]) new_grad = self._change_coordinates_for_dim(new_grad, state, 1, forward, - orig_dim=dim) + allow_store_stats, orig_dim=dim) return new_grad.reshape(new_shape).permute(*rev_dims_order) # OK: grad.ndim == 2 and dim == 1 @@ -649,8 +652,8 @@ class Cain(Optimizer): orig_dim = dim step = state["step"] must_store_stats, must_zero_stats = self._must_store_stats(step) - if must_store_stats: - # store stats for 20 iters preceding when we estimate the + if must_store_stats and allow_store_stats: + # store stats for 100 iters preceding when we estimate the # transform. stats_name = f"cov_{orig_dim}" if not stats_name in state: @@ -695,11 +698,11 @@ class Cain(Optimizer): """ if step < 4000: if step < 2000: - return (step % 500 >= 480, step % 500 == 480) + return (step % 500 >= 400, step % 500 == 400) else: - return (step % 1000 >= 980, step % 1000 == 980) + return (step % 1000 >= 800, step % 1000 == 800) else: - return (step % 2000 >= 1980, step % 2000 == 1980) + return (step % 2000 >= 1800, step % 2000 == 1800) def _must_estimate_transform(self, step: int) -> bool: """