mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
Pre and post-multiply by inv_sqrt_stddev,stddev
This commit is contained in:
parent
a270973b69
commit
75c822c7e9
@ -777,11 +777,12 @@ class Decorrelate(torch.nn.Module):
|
||||
Args:
|
||||
cov: matrix to normalize
|
||||
eps: floating point value >0, used to prevent division by zero.
|
||||
Returns normalized_cov, inv_sqrt_diag
|
||||
"""
|
||||
diag = cov.diag()
|
||||
inv_sqrt_diag = (diag + eps) ** -0.5
|
||||
cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
|
||||
return cov
|
||||
return cov, inv_sqrt_diag
|
||||
|
||||
|
||||
def _randperm_like(self, x: Tensor):
|
||||
@ -821,7 +822,7 @@ class Decorrelate(torch.nn.Module):
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
cov = self._get_covar(x)
|
||||
cov = self._normalize_covar(cov, self.eps)
|
||||
cov, inv_sqrt_diag = self._normalize_covar(cov, self.eps)
|
||||
avg_squared_eig = (cov**2).sum(dim=0).mean()
|
||||
if random.random() < 0.001 or __name__ == "__main__":
|
||||
logging.info(f"Decorrelate: avg_squared_eig = {avg_squared_eig}")
|
||||
@ -829,13 +830,15 @@ class Decorrelate(torch.nn.Module):
|
||||
# the odd-looking formula below was obtained empirically, to match
|
||||
# the self-product and cross-correlation statistics of dropout
|
||||
|
||||
x = x * inv_sqrt_diag
|
||||
|
||||
rand_scale1 = ((self.max_dropout_rate / (1.0 - self.max_dropout_rate)) ** 0.5) / avg_squared_eig
|
||||
rand_scale2 = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5)
|
||||
rand_scale = torch.minimum(rand_scale1, torch.tensor(rand_scale2, device=x.device))
|
||||
|
||||
# by multiplying by `cov`, then randomizing the sign of elements, then
|
||||
# multiplying by `cov` again, we are generating something that has
|
||||
# more noise in directions corresponding to larger eigenvlues of `cov`.
|
||||
# more noise in directions corresponding to larger eigenvalues of `cov`.
|
||||
# (Actually we scale by the square of the eigenvalue, which is not very
|
||||
# desirable, but was easy to implement in a fast way
|
||||
x = torch.matmul(x * rand_scale, cov)
|
||||
@ -852,6 +855,8 @@ class Decorrelate(torch.nn.Module):
|
||||
x = x_next
|
||||
|
||||
x = torch.matmul(x, cov)
|
||||
x = x / inv_sqrt_diag
|
||||
|
||||
x = x + x_bypass
|
||||
x = x.transpose(self.channel_dim, -1)
|
||||
return x
|
||||
|
Loading…
x
Reference in New Issue
Block a user