Implement new, more principled but maybe slower version.

Daniel Povey 2022-06-07 23:38:38 +08:00
parent 75c822c7e9
commit a6050cb2de
2 changed files with 103 additions and 51 deletions


@@ -736,52 +736,111 @@ class Decorrelate(torch.nn.Module):
- max_dropout_rate: This is an upper limit, for safety, on how aggressive the
- randomization can be.
eps: An epsilon used to prevent division by zero.
+ beta: A value 0 < beta < 1 that controls decay of covariance stats.
"""
def __init__(self,
num_channels: int,
apply_prob: float = 0.25,
- dropout_rate: float = 0.01,
- max_dropout_rate: float = 0.1,
+ dropout_rate: float = 0.1,
eps: float = 1.0e-04,
+ beta: float = 0.95,
channel_dim: int = -1):
super(Decorrelate, self).__init__()
self.apply_prob = apply_prob
self.dropout_rate = dropout_rate
- self.max_dropout_rate = max_dropout_rate
self.channel_dim = channel_dim
self.eps = eps
+ self.beta = beta
rand_mat = torch.randn(num_channels, num_channels)
U, _, _ = rand_mat.svd()
self.register_buffer('U', U) # a random orthogonal square matrix. will be a buffer.
+ self.register_buffer('T1', torch.eye(num_channels))
+ self.register_buffer('rand_scales', torch.zeros(num_channels))
+ self.register_buffer('nonrand_scales', torch.ones(num_channels))
+ self.register_buffer('T2', torch.eye(num_channels))
+ self.register_buffer('cov', torch.zeros(num_channels, num_channels))
+ self.step = 0
- def _get_covar(self, x: Tensor) -> Tensor:
+ def _update_covar_stats(self, x: Tensor) -> None:
"""
- Returns the uncentered covariance matrix associated with feature matrix x, detached
- from its input.
Args:
x: Tensor of shape (*, num_channels)
- Returns:
- Covariance matrix `cov`, of shape (num_channels, num_channels)
+ Updates covariance stats self.cov.
"""
x = x.detach()
x = x.reshape(-1, x.shape[-1])
x = x * (x.shape[0] ** -0.5) # avoid overflow in half precision
- return torch.matmul(x.t(), x)
+ cov = torch.matmul(x.t(), x)
+ self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
+ self.step += 1
- def _normalize_covar(self, cov: Tensor, eps: float) -> Tensor:
+ def _update_transforms(self):
+ norm_cov, inv_sqrt_diag = self._normalize_covar(self.cov)
+ U, S, _ = norm_cov.svd()
+ if random.random() < 0.1:
+ print(f"Decorrelate: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
+ # row indexes of U correspond to channels, column indexes correspond to
+ # singular values: cov = U * diag(S) * U.t() where * is matmul.
+ S_eps = S + self.eps
+ S_sqrt = S_eps ** 0.5
+ S_inv_sqrt = S_eps ** -0.5
+ # Transform T1, which we'll incorporate as torch.matmul(x, self.T1), is:
+ # (i) multiply by inv_sqrt_diag, which makes the covariance have
+ # a unit diagonal,
+ # (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels),
+ # (iii) divide by S_sqrt, which makes all dims have unit variance.
+ self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U / S_sqrt)
+ # Transform T2, which we'll incorporate as torch.matmul(x, self.T2), is:
+ # (i) multiply by S_sqrt, which restores the variance of different dims,
+ # (ii) multiply by U.t(), which rotates back out of the eigenbasis (re-correlates channels),
+ # (iii) divide by inv_sqrt_diag, which makes the covariance have its original
+ # diagonal values.
+ self.T2[:] = (S_sqrt.unsqueeze(-1) * U.t() / inv_sqrt_diag)
+ # OK, now get rand_scales; the underlying rand_proportion takes values between
+ # 0 and self.dropout_rate, and says how much randomness goes into each
+ # eigen-direction of norm_cov.
+ # Basically, we want more randomness in directions with eigenvalues more than one,
+ # and none in those with eigenvalues less than one.
+ rand_proportion = (S - 1.0).clamp(min=0.0, max=1.0) * self.dropout_rate
+ # rand_proportion is viewed as representing a proportion of the covariance, since
+ # the random and nonrandom components will not be correlated.
+ self.rand_scales = rand_proportion.sqrt()
+ self.nonrand_scales = (1.0 - rand_proportion).sqrt()
+ if True:
+ d = torch.matmul(self.T1, self.T2) - torch.eye(self.T1.shape[0],
+ device=self.T1.device,
+ dtype=self.T1.dtype)
+ assert torch.all(d.abs() < 0.01)
+ def _normalize_covar(self, cov: Tensor) -> Tuple[Tensor, Tensor]:
"""
Normalizes a covariance matrix so that its diagonal is 1, by multiplying by
its diagonal**-0.5 on both sides.
Args:
cov: matrix to normalize
- eps: floating point value >0, used to prevent division by zero.
Returns normalized_cov, inv_sqrt_diag
"""
diag = cov.diag()
- inv_sqrt_diag = (diag + eps) ** -0.5
+ inv_sqrt_diag = (diag + self.eps) ** -0.5
cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
assert torch.all((cov.diag() - 1.0).abs() < 0.1) # TODO: remove
return cov, inv_sqrt_diag
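
To make the algebra above concrete, here is a minimal, self-contained sketch of how the whitening/unwhitening pair T1/T2 and the per-eigenvalue randomness scales fall out of a covariance matrix. It mirrors the logic of _update_transforms and _normalize_covar; the helper name make_transforms, the dimensions, and the final checks are illustrative assumptions, not part of the commit:

    import torch

    def make_transforms(cov: torch.Tensor, dropout_rate: float, eps: float = 1.0e-04):
        # Normalize cov to a unit diagonal: norm_cov = D^-1/2 * cov * D^-1/2.
        inv_sqrt_diag = (cov.diag() + eps) ** -0.5
        norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
        # For a symmetric positive semi-definite matrix, svd gives
        # norm_cov = U diag(S) U.t().
        U, S, _ = norm_cov.svd()
        S_sqrt = (S + eps) ** 0.5
        # T1: unit-diagonalize, rotate into the eigenbasis, scale to unit variance.
        T1 = inv_sqrt_diag.unsqueeze(-1) * U / S_sqrt
        # T2 undoes T1: restore eigen-variances, rotate back, restore the diagonal.
        T2 = S_sqrt.unsqueeze(-1) * U.t() / inv_sqrt_diag
        # More randomness in directions whose eigenvalues exceed 1, none below 1.
        rand_proportion = (S - 1.0).clamp(min=0.0, max=1.0) * dropout_rate
        return T1, T2, rand_proportion.sqrt(), (1.0 - rand_proportion).sqrt()

    x = torch.randn(10000, 64) @ torch.randn(64, 64)   # give x a non-unit covariance
    cov = (x.t() @ x) / x.shape[0]
    T1, T2, rand_scales, nonrand_scales = make_transforms(cov, dropout_rate=0.1)
    # T2 inverts T1; this is what the assert in _update_transforms verifies.
    assert torch.all((T1 @ T2 - torch.eye(64)).abs() < 0.01)
    # x @ T1 is whitened: its covariance is close to the identity.
    w = x @ T1
    print((w.t() @ w / w.shape[0]).diag()[:4])   # entries close to 1.0, up to eps effects

Because both factors use the same (S + eps), the product T1 @ T2 cancels exactly to the identity, which is what the `if True:` debug block above checks.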
@@ -818,46 +877,33 @@ class Decorrelate(torch.nn.Module):
return x
else:
x = x.transpose(self.channel_dim, -1) # (..., num_channels)
- x_bypass = x # will be used for "+ I"
with torch.cuda.amp.autocast(enabled=False):
- cov = self._get_covar(x)
- cov, inv_sqrt_diag = self._normalize_covar(cov, self.eps)
- avg_squared_eig = (cov**2).sum(dim=0).mean()
- if random.random() < 0.001 or __name__ == "__main__":
- logging.info(f"Decorrelate: avg_squared_eig = {avg_squared_eig}")
+ self._update_covar_stats(x)
+ if self.step % 50 == 0 or __name__ == "__main__":
+ self._update_transforms()
- # the odd-looking formula below was obtained empirically, to match
- # the self-product and cross-correlation statistics of dropout
+ x = torch.matmul(x, self.T1)
- x = x * inv_sqrt_diag
+ x_bypass = x
- rand_scale1 = ((self.max_dropout_rate / (1.0 - self.max_dropout_rate)) ** 0.5) / avg_squared_eig
- rand_scale2 = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5)
- rand_scale = torch.minimum(rand_scale1, torch.tensor(rand_scale2, device=x.device))
- if True:
- # This block, in effect, multiplies x by a random orthogonal matrix,
- # giving us random noise.
- perm = self._randperm_like(x)
- x = torch.gather(x, -1, perm)
- # self.U will act like a different matrix for every row of x,
- # because of the random permutation.
- x = torch.matmul(x, self.U)
- x_next = torch.empty_like(x)
- # scatter_ uses perm in opposite way
- # from gather, inverting it.
- x_next.scatter_(-1, perm, x)
- x = x_next
- # by multiplying by `cov`, then randomizing the sign of elements, then
- # multiplying by `cov` again, we are generating something that has
- # more noise in directions corresponding to larger eigenvalues of `cov`.
- # (Actually we scale by the square of the eigenvalue, which is not very
- # desirable, but was easy to implement in a fast way.)
- x = torch.matmul(x * rand_scale, cov)
+ x = (x * self.rand_scales) + (x_bypass * self.nonrand_scales)
+ perm = self._randperm_like(x)
+ x = torch.gather(x, -1, perm)
+ # self.U will act like a different matrix for every row of x,
+ # because of the random permutation.
+ x = torch.matmul(x, self.U)
+ x_next = torch.empty_like(x)
+ # scatter_ uses perm in opposite way
+ # from gather, inverting it.
+ x_next.scatter_(-1, perm, x)
+ x = x_next
- x = torch.matmul(x, cov)
- x = x / inv_sqrt_diag
- x = x + x_bypass
+ x = torch.matmul(x, self.T2)
x = x.transpose(self.channel_dim, -1)
return x
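
The noise used in this forward pass comes from composing a per-row random permutation with the fixed orthogonal matrix U, which in effect applies a different orthogonal matrix to every row of x. A standalone sketch of the trick (torch.argsort over random values is assumed here as a stand-in for _randperm_like, whose body is not shown in this diff):

    import torch

    def random_orthogonal_noise(x: torch.Tensor, U: torch.Tensor) -> torch.Tensor:
        # permutation -> fixed rotation -> inverse permutation: a cheap way to
        # apply a "different" orthogonal transform to every row of x.
        perm = torch.argsort(torch.rand_like(x), dim=-1)
        y = torch.gather(x, -1, perm)      # apply the per-row permutation
        y = torch.matmul(y, U)             # fixed random orthogonal rotation
        out = torch.empty_like(y)
        out.scatter_(-1, perm, y)          # scatter_ inverts the gather permutation
        return out

    num_channels = 64
    U, _, _ = torch.randn(num_channels, num_channels).svd()  # random orthogonal matrix
    x = torch.randn(8, num_channels)
    noise = random_orthogonal_noise(x, U)
    # Orthogonal maps preserve the norm of each row:
    print((x.norm(dim=-1) - noise.norm(dim=-1)).abs().max())  # ~0

This rotated copy is blended with the bypass path via rand_scales and nonrand_scales; since rand_proportion + (1 - rand_proportion) = 1 and the rotated copy is nearly uncorrelated with the original, the expected per-channel power in the whitened space is approximately preserved.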
@@ -938,12 +984,13 @@ def _test_double_swish_deriv():
def _test_gauss_proj_drop():
- x = torch.randn(30000, 384)
+ D = 384
+ x = torch.randn(30000, D)
for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
m1 = torch.nn.Dropout(dropout_rate)
- m2 = GaussProjDrop(384, dropout_rate)
+ m2 = GaussProjDrop(D, dropout_rate)
for mode in ['train', 'eval']:
y1 = m1(x)
y2 = m2(x)
@@ -958,12 +1005,17 @@ def _test_gauss_proj_drop():
def _test_decorrelate():
logging.getLogger().setLevel(logging.INFO)
- x = torch.randn(30000, 384)
+ D = 384
+ x = torch.randn(30000, D)
+ # give it a non-unit covariance.
+ m = torch.randn(D, D)
+ x = torch.matmul(x, m)
for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
m1 = torch.nn.Dropout(dropout_rate)
- m2 = Decorrelate(384, apply_prob=1.0, dropout_rate=dropout_rate, max_dropout_rate=dropout_rate)
+ m2 = Decorrelate(D, apply_prob=1.0, dropout_rate=dropout_rate)
for mode in ['train', 'eval']:
y1 = m1(x)
y2 = m2(x)
@@ -972,7 +1024,7 @@ def _test_decorrelate():
cross1 = (x*y1).mean()
y2mag = (y2*y2).mean()
cross2 = (x*y2).mean()
print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}")
print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}, ratio1={y1mag/cross1}, ratio2={y2mag/cross2}")
m1.eval()
m2.eval()
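
The new ratio1/ratio2 columns printed above have a simple reference value: for inverted dropout with rate p, E[x*y] = E[x^2] while E[y^2] = E[x^2]/(1 - p), so ymag/cross should approach 1/(1 - p), and Decorrelate is being tuned to match that statistic. A standalone check of the dropout identity (not part of the commit):

    import torch

    # Inverted dropout: y = x * mask / (1 - p), mask ~ Bernoulli(1 - p).
    # Then E[x*y] = E[x^2] and E[y^2] = E[x^2] / (1 - p), so the ratio
    # (y*y).mean() / (x*y).mean() should converge to 1 / (1 - p).
    p = 0.1
    x = torch.randn(1_000_000)
    y = torch.nn.functional.dropout(x, p=p, training=True)
    print(((y * y).mean() / (x * y).mean()).item())  # close to 1 / (1 - p) = 1.111...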


@@ -199,7 +199,7 @@ class ConformerEncoderLayer(nn.Module):
)
self.dropout = torch.nn.Dropout(dropout)
- self.decorrelate = Decorrelate(d_model, apply_prob=0.25)
+ self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.2)
def forward(
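
For context, a hypothetical usage sketch of the module as configured above; the shapes follow the (seq_len, batch, d_model) convention of ConformerEncoderLayer with the default channel_dim=-1, the import path mirrors how conformer.py pulls helpers from scaling.py, and the eval-mode identity follows from the training-mode guard at the top of forward:

    import torch
    from scaling import Decorrelate  # assumed import path, as in conformer.py

    d_model = 512
    decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.2)
    src = torch.randn(100, 8, d_model)   # (seq_len, batch, channel)
    decorrelate.train()
    out = decorrelate(src)               # randomized, mostly in high-variance directions
    decorrelate.eval()
    assert torch.equal(decorrelate(src), src)  # no-op at inference time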