mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Fix how the cov stats are estimated: more stats are now stored, and the stats are now always zeroed.
This commit is contained in:
parent
6f974b32f6
commit
11eac9089e
@ -666,20 +666,17 @@ class Cain(Optimizer):
|
||||
ndim = grad.ndim
|
||||
numel = grad.numel()
|
||||
step = state["step"]
|
||||
for i in range(grad.ndim):
|
||||
|
||||
dim = (i + step) % grad.ndim
|
||||
|
||||
for dim in range(grad.ndim):
|
||||
# see for each dim in turn whether we want to perform any changes in co-ordinates,
|
||||
# or store any stats.
|
||||
size = grad.shape[dim]
|
||||
if size <= 3 or size % 2 == 1 or size >= 2048 or size == numel:
|
||||
# we don't do any such co-ordinate changes in dims that are too small (no point)
|
||||
# or large (too slow), or convolutional (odd). can revisit this later.
|
||||
# we don't do any such co-ordinate changes in dims with sizes
|
||||
# that are too small (no point) or large (too slow), or that are
|
||||
# assumed convolutional (because they are odd). We can revisit
|
||||
# this later.
|
||||
continue
|
||||
allow_store_stats = (i == grad.ndim - 1)
|
||||
grad = self._change_coordinates_for_dim(grad, state, dim, forward,
|
||||
allow_store_stats)
|
||||
grad = self._change_coordinates_for_dim(grad, state, dim, forward)
|
||||
return grad
|
||||
|
||||
|
||||
@ -688,7 +685,6 @@ class Cain(Optimizer):
|
||||
state: dict,
|
||||
dim: int,
|
||||
forward: bool,
|
||||
allow_store_stats: bool,
|
||||
orig_dim: int = -1):
|
||||
assert grad.ndim > 1
|
||||
if not (grad.ndim == 2 and dim == 1):
|
||||
@ -710,7 +706,7 @@ class Cain(Optimizer):
|
||||
new_shape = new_grad.shape
|
||||
new_grad = new_grad.reshape(-1, new_grad.shape[-1])
|
||||
new_grad = self._change_coordinates_for_dim(new_grad, state, 1, forward,
|
||||
allow_store_stats, orig_dim=dim)
|
||||
orig_dim=dim)
|
||||
return new_grad.reshape(new_shape).permute(*rev_dims_order)
|
||||
|
||||
# OK: grad.ndim == 2 and dim == 1
|
||||
@ -719,8 +715,8 @@ class Cain(Optimizer):
|
||||
orig_dim = dim
|
||||
step = state["step"]
|
||||
must_store_stats, must_zero_stats = self._must_store_stats(step)
|
||||
if must_store_stats and allow_store_stats and forward:
|
||||
# store stats for 100 iters preceding when we estimate the
|
||||
if must_store_stats and forward:
|
||||
# store stats for 200 iters preceding when we estimate the
|
||||
# transform.
|
||||
stats_name = f"cov_{orig_dim}"
|
||||
if not stats_name in state:
|
||||
@ -728,7 +724,6 @@ class Cain(Optimizer):
|
||||
device=grad.device)
|
||||
cov = state[stats_name]
|
||||
if must_zero_stats:
|
||||
#print("zero")
|
||||
cov.zero_()
|
||||
cov += torch.matmul(grad.t(), grad) * (1/grad.shape[0])
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user