mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Make random cov have correct diagonal; reset grad_min_rand to 0.
This commit is contained in:
parent
3f0b57403e
commit
3fb3cc4e23
@ -93,7 +93,7 @@ class NeutralGradient(Optimizer):
|
||||
estimate_period=2000,
|
||||
stats_steps=200,
|
||||
param_pow=1.0,
|
||||
grad_min_rand=0.2,
|
||||
grad_min_rand=0.0,
|
||||
lr_for_speedup=0.03,
|
||||
speedup_recalibrate_period=100,
|
||||
speedup_first_step=500,
|
||||
@ -512,7 +512,8 @@ class NeutralGradient(Optimizer):
|
||||
grad_cov = state[f"grad_cov_{dim}"] / count
|
||||
del state[f"grad_cov_{dim}"] # save memory
|
||||
|
||||
self._randomize_lowrank_cov(grad_cov, count, p.shape)
|
||||
self._randomize_lowrank_cov(grad_cov, count, p.shape,
|
||||
grad_min_rand)
|
||||
|
||||
param_cov = self._get_param_cov(p, dim)
|
||||
|
||||
@ -703,22 +704,27 @@ class NeutralGradient(Optimizer):
|
||||
size = cov.shape[0]
|
||||
required_rank = int(size * 1.2) + 1
|
||||
|
||||
if rank < required_rank:
|
||||
rand_scale = min_rand
|
||||
if rank < required_rank:
|
||||
# This epsilon is to avoid division by zero in weird situations where the
|
||||
# params are exactly zero or we sample a zero matrix; the exact value is not
|
||||
# going to affect the performance.
|
||||
eps = 1.0e-20
|
||||
cov_scale = cov.diag().mean() + 1.0e-20
|
||||
|
||||
# The following formula assumes that the "missing" outer products are
|
||||
# expected to be smaller than the ones that we do have, i.e. 0.2 times the size.
|
||||
# We are trying to find the trace of the matrix we need to add,
|
||||
# i.e. how big it is relative to the trace of the existing matrix.
|
||||
missing_rank = (required_rank - rank)
|
||||
rand_scale = max(min_rand, 0.2 * missing_rank / rank)
|
||||
rand_scale = max(rand_scale, 0.2 * missing_rank / rank)
|
||||
if rand_scale > 0.0:
|
||||
R = torch.randn(size, size, device=cov.device, dtype=cov.dtype)
|
||||
R = torch.matmul(R, R.t()) # positive semidefinite random matrix
|
||||
rms_diag = (cov.diag() + 1.0e-20).sqrt()
|
||||
R = R * rms_diag * rms_diag.unsqueeze(-1)
|
||||
|
||||
R_scale = R.diag().mean() + 1.0e-20
|
||||
cov_scale = cov.diag().mean() + 1.0e-20
|
||||
if random.random() < 0.1:
|
||||
print(f"required_rank={required_rank}, rank={rank}, size={size}, R_scale={R_scale}, rand_scale={rand_scale}, shape={shape}")
|
||||
cov.add_(R, alpha=rand_scale * cov_scale / R_scale)
|
||||
|
Loading…
x
Reference in New Issue
Block a user