Pre and post-multiply by inv_sqrt_stddev,stddev

Daniel Povey 2022-06-07 20:32:18 +08:00
parent a270973b69
commit 75c822c7e9


@@ -777,11 +777,12 @@ class Decorrelate(torch.nn.Module):
         Args:
           cov: matrix to normalize
           eps: floating point value >0, used to prevent division by zero.
+        Returns normalized_cov, inv_sqrt_diag
         """
         diag = cov.diag()
         inv_sqrt_diag = (diag + eps) ** -0.5
         cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
-        return cov
+        return cov, inv_sqrt_diag

     def _randperm_like(self, x: Tensor):
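For readers skimming the patch, here is a minimal standalone sketch of what the updated helper computes (the free-function wrapper, the default `eps`, and the demo values are assumptions; only the three lines of arithmetic come from the hunk above). It rescales the covariance to have roughly unit diagonal, i.e. covariance to correlation, and after this commit also returns the inverse-sqrt diagonal so the caller can whiten and later un-whiten:

    import torch

    def normalize_covar(cov: torch.Tensor, eps: float = 1e-5):
        # Rescale `cov` so its diagonal becomes ~1.0 (covariance -> correlation);
        # `eps` guards against division by zero for near-zero variances.
        diag = cov.diag()
        inv_sqrt_diag = (diag + eps) ** -0.5
        cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
        return cov, inv_sqrt_diag

    # Quick check on a random positive semi-definite matrix:
    a = torch.randn(4, 8)
    cov = torch.matmul(a, a.t()) / 8
    norm_cov, inv_sqrt_diag = normalize_covar(cov)
    print(norm_cov.diag())  # entries all close to 1.0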
@@ -821,7 +822,7 @@ class Decorrelate(torch.nn.Module):
             with torch.cuda.amp.autocast(enabled=False):
                 cov = self._get_covar(x)
-                cov = self._normalize_covar(cov, self.eps)
+                cov, inv_sqrt_diag = self._normalize_covar(cov, self.eps)
                 avg_squared_eig = (cov**2).sum(dim=0).mean()
                 if random.random() < 0.001 or __name__ == "__main__":
                     logging.info(f"Decorrelate: avg_squared_eig = {avg_squared_eig}")
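A note on the unchanged `avg_squared_eig` line above: for a symmetric matrix, the squared Frobenius norm equals the sum of squared eigenvalues, so `(cov**2).sum(dim=0).mean()` gives the mean squared eigenvalue without an eigendecomposition. A quick numerical check of that identity (the variable names are illustrative, not from the patch):

    import torch

    a = torch.randn(6, 6)
    cov = (a + a.t()) / 2  # any symmetric matrix

    # Frobenius-norm route used in the patch vs. explicit eigenvalues:
    avg_sq_fast = (cov ** 2).sum(dim=0).mean()
    avg_sq_ref = (torch.linalg.eigvalsh(cov) ** 2).mean()
    print(torch.allclose(avg_sq_fast, avg_sq_ref))  # True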
@@ -829,13 +830,15 @@ class Decorrelate(torch.nn.Module):
                 # the odd-looking formula below was obtained empirically, to match
                 # the self-product and cross-correlation statistics of dropout
+                x = x * inv_sqrt_diag
                 rand_scale1 = ((self.max_dropout_rate / (1.0 - self.max_dropout_rate)) ** 0.5) / avg_squared_eig
                 rand_scale2 = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5)
                 rand_scale = torch.minimum(rand_scale1, torch.tensor(rand_scale2, device=x.device))
                 # by multiplying by `cov`, then randomizing the sign of elements, then
                 # multiplying by `cov` again, we are generating something that has
-                # more noise in directions corresponding to larger eigenvlues of `cov`.
+                # more noise in directions corresponding to larger eigenvalues of `cov`.
                 # (Actually we scale by the square of the eigenvalue, which is not very
                 # desirable, but was easy to implement in a fast way.)
                 x = torch.matmul(x * rand_scale, cov)
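Two things happen in this hunk: the input is now pre-scaled by `inv_sqrt_diag` so its channels match the unit-diagonal `cov`, and the `p / (1 - p)` factors are the noise-to-signal variance ratio of standard inverted dropout (a multiplier of Bernoulli(1-p)/(1-p) has mean 1 and variance p/(1-p)), which is the dropout statistic the empirical formula matches. The eigenvalue claim in the comment can also be checked directly: pushing white noise through a symmetric `cov` yields covariance `cov @ cov`, i.e. variance lambda^2 along each eigenvector. The snippet below is an illustration of that effect, not code from the patch:

    import torch

    torch.manual_seed(0)
    dim, n = 8, 200_000
    a = torch.randn(dim, dim)
    cov = (a @ a.t()) / dim        # symmetric PSD stand-in for `cov`

    z = torch.randn(n, dim)        # rows of white noise
    shaped = z @ cov               # one multiplication by `cov`

    # Empirical covariance of the shaped noise is ~ cov @ cov, so its
    # eigenvalues are the squares of cov's eigenvalues.
    emp = shaped.t() @ shaped / n
    print(torch.linalg.eigvalsh(emp))
    print(torch.linalg.eigvalsh(cov) ** 2)  # close match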
@@ -852,6 +855,8 @@ class Decorrelate(torch.nn.Module):
                     x = x_next
                 x = torch.matmul(x, cov)
+                x = x / inv_sqrt_diag
                 x = x + x_bypass
             x = x.transpose(self.channel_dim, -1)
             return x
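Putting the four hunks together: the commit wraps the existing cov-shaped noise step in a whiten/un-whiten pair, multiplying by `inv_sqrt_diag` (one over the per-channel standard deviation) on the way in and dividing by it on the way out, which appears to be what the commit title means by pre- and post-multiplying by inv_sqrt_stddev and stddev. A condensed sketch of that flow; everything not quoted in the hunks, in particular the sign-randomization stand-in for the elided `x_next` loop, is an assumption:

    import torch

    def decorrelate_noise_sketch(x, cov, inv_sqrt_diag, rand_scale):
        x_bypass = x
        # Pre-multiply: whiten each channel so it matches unit-diagonal `cov`.
        x = x * inv_sqrt_diag
        # Shape the noise: multiply by cov, randomize signs (hypothetical
        # stand-in for the elided loop), multiply by cov again.
        x = torch.matmul(x * rand_scale, cov)
        signs = (torch.rand_like(x) < 0.5).to(x.dtype) * 2.0 - 1.0
        x = x * signs
        x = torch.matmul(x, cov)
        # Post-multiply: restore the original per-channel scale, then bypass-add.
        x = x / inv_sqrt_diag
        return x + x_bypass

The point of the scaling pair is that the injected noise respects each channel's actual standard deviation even though `cov` itself was normalized to unit diagonal by `_normalize_covar`.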