Pre and post-multiply by inv_sqrt_stddev, stddev

Repository: https://github.com/k2-fsa/icefall.git
Commit: 75c822c7e9 (parent: a270973b69)
@@ -777,11 +777,12 @@ class Decorrelate(torch.nn.Module):
         Args:
            cov: matrix to normalize
            eps: floating point value >0, used to prevent division by zero.
+        Returns normalized_cov, inv_sqrt_diag
         """
         diag = cov.diag()
         inv_sqrt_diag = (diag + eps) ** -0.5
         cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
-        return cov
+        return cov, inv_sqrt_diag


     def _randperm_like(self, x: Tensor):
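Note on this hunk: `_normalize_covar` now also returns the per-channel inverse square-root of the diagonal, so the caller can pre-multiply and post-multiply by inv_sqrt_stddev and stddev. A minimal standalone sketch of what the updated helper computes (the helper body matches the diff; the test scaffolding around it is illustrative):

import torch

def _normalize_covar(cov: torch.Tensor, eps: float = 1e-5):
    # Scale `cov` to have a unit diagonal (i.e. turn it into a
    # correlation matrix), and also return the inverse-sqrt diagonal so
    # the caller can whiten inputs before using this matrix and
    # un-whiten them afterwards.
    diag = cov.diag()
    inv_sqrt_diag = (diag + eps) ** -0.5
    cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
    return cov, inv_sqrt_diag

a = torch.randn(8, 4)
cov = a.t() @ a / 8                    # symmetric positive semi-definite
norm_cov, inv_sqrt_diag = _normalize_covar(cov)
print(norm_cov.diag())                 # entries all close to 1.0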
@@ -821,7 +822,7 @@ class Decorrelate(torch.nn.Module):

         with torch.cuda.amp.autocast(enabled=False):
             cov = self._get_covar(x)
-            cov = self._normalize_covar(cov, self.eps)
+            cov, inv_sqrt_diag = self._normalize_covar(cov, self.eps)
             avg_squared_eig = (cov**2).sum(dim=0).mean()
             if random.random() < 0.001 or __name__ == "__main__":
                 logging.info(f"Decorrelate: avg_squared_eig = {avg_squared_eig}")
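For a symmetric `cov`, `(cov**2).sum(dim=0).mean()` equals trace(cov @ cov) / n, i.e. the mean of the squared eigenvalues, which is why the result is called `avg_squared_eig`. A quick numerical check (illustrative, not part of the commit):

import torch

a = torch.randn(16, 6)
cov = a.t() @ a / 16                   # symmetric, cov = Q diag(l) Q^T
avg_sq = (cov**2).sum(dim=0).mean()    # squared Frobenius norm over n
eigs = torch.linalg.eigvalsh(cov)
print(torch.allclose(avg_sq, (eigs**2).mean()))   # True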
@@ -829,13 +830,15 @@ class Decorrelate(torch.nn.Module):
             # the odd-looking formula below was obtained empirically, to match
             # the self-product and cross-correlation statistics of dropout

+            x = x * inv_sqrt_diag
+
             rand_scale1 = ((self.max_dropout_rate / (1.0 - self.max_dropout_rate)) ** 0.5) / avg_squared_eig
             rand_scale2 = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5)
             rand_scale = torch.minimum(rand_scale1, torch.tensor(rand_scale2, device=x.device))

             # by multiplying by `cov`, then randomizing the sign of elements, then
             # multiplying by `cov` again, we are generating something that has
-            # more noise in directions corresponding to larger eigenvlues of `cov`.
+            # more noise in directions corresponding to larger eigenvalues of `cov`.
             # (Actually we scale by the square of the eigenvalue, which is not very
             # desirable, but was easy to implement in a fast way
             x = torch.matmul(x * rand_scale, cov)
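The comments above explain why the noise is shaped by `cov` twice: each application scales an eigen-direction by its eigenvalue, so the end-to-end gain in a direction goes as the eigenvalue squared. A small demonstration of the double application (the sign-randomization step that sits between the two products in the actual code is omitted here; the setup is assumed):

import torch

a = torch.randn(32, 5)
cov = a.t() @ a / 32                  # symmetric, cov = Q diag(l) Q^T
eigs, q = torch.linalg.eigh(cov)

v = torch.randn(5)
w = (v @ cov) @ cov                   # apply cov twice
comp_in = q.t() @ v                   # components in the eigenbasis
comp_out = q.t() @ w
print(torch.allclose(comp_out, comp_in * eigs**2, atol=1e-5))  # True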
@@ -852,6 +855,8 @@ class Decorrelate(torch.nn.Module):
                 x = x_next

             x = torch.matmul(x, cov)
+            x = x / inv_sqrt_diag
+
             x = x + x_bypass
             x = x.transpose(self.channel_dim, -1)
             return x
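Taken together, the change wraps the existing noise generation in a whitening/un-whitening pair: each channel is divided by its standard deviation before the correlated noise is produced, and rescaled afterwards, so high-variance channels no longer dominate the correlation-based noise. The new forward path, schematically (paraphrased from the hunks above; surrounding details such as the sign-randomization loop are elided):

# inside forward(), schematically:
cov, inv_sqrt_diag = self._normalize_covar(self._get_covar(x), self.eps)

x = x * inv_sqrt_diag     # pre-multiply by inv_sqrt_stddev: unit-variance channels
# ... scale by rand_scale, multiply by `cov`, randomize signs,
# ... multiply by `cov` again
x = x / inv_sqrt_diag     # post-multiply by stddev: restore original scale
x = x + x_bypass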