mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Make cain average over more iters and use preconditioning on the other dims first
This commit is contained in:
parent
ac5a9faafd
commit
6085ab64ef
@ -598,19 +598,21 @@ class Cain(Optimizer):
|
|||||||
"""
|
"""
|
||||||
ndim = grad.ndim
|
ndim = grad.ndim
|
||||||
numel = grad.numel()
|
numel = grad.numel()
|
||||||
for dim in range(grad.ndim):
|
step = state["step"]
|
||||||
|
for i in range(grad.ndim):
|
||||||
|
|
||||||
|
dim = (i + step) % grad.ndim
|
||||||
|
|
||||||
# see for each dim in turn whether we want to perform any changes in co-ordinates,
|
# see for each dim in turn whether we want to perform any changes in co-ordinates,
|
||||||
# or store any stats.
|
# or store any stats.
|
||||||
size = grad.shape[dim]
|
size = grad.shape[dim]
|
||||||
if size <= 3 or size > 1024 or size == numel:
|
if size <= 3 or size % 2 == 1 or size >= 2048 or size == numel:
|
||||||
# we don't do any such co-ordinate changes in dims that are too small (no point)
|
# we don't do any such co-ordinate changes in dims that are too small (no point)
|
||||||
# or large (too slow)
|
# or large (too slow), or convolutional (odd). can revisit this later.
|
||||||
continue
|
continue
|
||||||
if dim == 0:
|
allow_store_stats = (i == grad.ndim - 1)
|
||||||
# FOR NOW: don't do such co-ordinate changes on output
|
grad = self._change_coordinates_for_dim(grad, state, dim, forward,
|
||||||
# dims, which will generally be dimension zero. We can revisit this later.
|
allow_store_stats)
|
||||||
continue
|
|
||||||
grad = self._change_coordinates_for_dim(grad, state, dim, forward)
|
|
||||||
return grad
|
return grad
|
||||||
|
|
||||||
|
|
||||||
@ -619,6 +621,7 @@ class Cain(Optimizer):
|
|||||||
state: dict,
|
state: dict,
|
||||||
dim: int,
|
dim: int,
|
||||||
forward: bool,
|
forward: bool,
|
||||||
|
allow_store_stats: bool,
|
||||||
orig_dim: int = -1):
|
orig_dim: int = -1):
|
||||||
assert grad.ndim > 1
|
assert grad.ndim > 1
|
||||||
if not (grad.ndim == 2 and dim == 1):
|
if not (grad.ndim == 2 and dim == 1):
|
||||||
@ -640,7 +643,7 @@ class Cain(Optimizer):
|
|||||||
new_shape = new_grad.shape
|
new_shape = new_grad.shape
|
||||||
new_grad = new_grad.reshape(-1, new_grad.shape[-1])
|
new_grad = new_grad.reshape(-1, new_grad.shape[-1])
|
||||||
new_grad = self._change_coordinates_for_dim(new_grad, state, 1, forward,
|
new_grad = self._change_coordinates_for_dim(new_grad, state, 1, forward,
|
||||||
orig_dim=dim)
|
allow_store_stats, orig_dim=dim)
|
||||||
return new_grad.reshape(new_shape).permute(*rev_dims_order)
|
return new_grad.reshape(new_shape).permute(*rev_dims_order)
|
||||||
|
|
||||||
# OK: grad.ndim == 2 and dim == 1
|
# OK: grad.ndim == 2 and dim == 1
|
||||||
@ -649,8 +652,8 @@ class Cain(Optimizer):
|
|||||||
orig_dim = dim
|
orig_dim = dim
|
||||||
step = state["step"]
|
step = state["step"]
|
||||||
must_store_stats, must_zero_stats = self._must_store_stats(step)
|
must_store_stats, must_zero_stats = self._must_store_stats(step)
|
||||||
if must_store_stats:
|
if must_store_stats and allow_store_stats:
|
||||||
# store stats for 20 iters preceding when we estimate the
|
# store stats for 100 iters preceding when we estimate the
|
||||||
# transform.
|
# transform.
|
||||||
stats_name = f"cov_{orig_dim}"
|
stats_name = f"cov_{orig_dim}"
|
||||||
if not stats_name in state:
|
if not stats_name in state:
|
||||||
@ -695,11 +698,11 @@ class Cain(Optimizer):
|
|||||||
"""
|
"""
|
||||||
if step < 4000:
|
if step < 4000:
|
||||||
if step < 2000:
|
if step < 2000:
|
||||||
return (step % 500 >= 480, step % 500 == 480)
|
return (step % 500 >= 400, step % 500 == 400)
|
||||||
else:
|
else:
|
||||||
return (step % 1000 >= 980, step % 1000 == 980)
|
return (step % 1000 >= 800, step % 1000 == 800)
|
||||||
else:
|
else:
|
||||||
return (step % 2000 >= 1980, step % 2000 == 1980)
|
return (step % 2000 >= 1800, step % 2000 == 1800)
|
||||||
|
|
||||||
def _must_estimate_transform(self, step: int) -> bool:
|
def _must_estimate_transform(self, step: int) -> bool:
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user