mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 17:14:20 +00:00
Implement new, more principled but maybe slower version.
This commit is contained in:
parent
75c822c7e9
commit
a6050cb2de
@ -736,52 +736,111 @@ class Decorrelate(torch.nn.Module):
|
||||
max_dropout_rate: This is an upper limit, for safety, on how aggressive the
|
||||
randomization can be.
|
||||
eps: An epsilon used to prevent division by zero.
|
||||
beta: A value 0 < beta < 1 that controls decay of covariance stats
|
||||
"""
|
||||
def __init__(self,
|
||||
num_channels: int,
|
||||
apply_prob: float = 0.25,
|
||||
dropout_rate: float = 0.01,
|
||||
max_dropout_rate: float = 0.1,
|
||||
dropout_rate: float = 0.1,
|
||||
eps: float = 1.0e-04,
|
||||
beta: float = 0.95,
|
||||
channel_dim: int = -1):
|
||||
super(Decorrelate, self).__init__()
|
||||
self.apply_prob = apply_prob
|
||||
self.dropout_rate = dropout_rate
|
||||
self.max_dropout_rate = max_dropout_rate
|
||||
self.channel_dim = channel_dim
|
||||
self.eps = eps
|
||||
self.beta = beta
|
||||
|
||||
rand_mat = torch.randn(num_channels, num_channels)
|
||||
U, _, _ = rand_mat.svd()
|
||||
|
||||
self.register_buffer('U', U) # a random orthogonal square matrix. will be a buffer.
|
||||
|
||||
self.register_buffer('T1', torch.eye(num_channels))
|
||||
self.register_buffer('rand_scales', torch.zeros(num_channels))
|
||||
self.register_buffer('nonrand_scales', torch.ones(num_channels))
|
||||
self.register_buffer('T2', torch.eye(num_channels))
|
||||
self.register_buffer('cov', torch.zeros(num_channels, num_channels))
|
||||
self.step = 0
|
||||
|
||||
def _get_covar(self, x: Tensor) -> Tensor:
|
||||
|
||||
|
||||
def _update_covar_stats(self, x: Tensor) -> None:
|
||||
"""
|
||||
Returns the uncentered covariance matrix associated with feature matrix x, detached
|
||||
from its input.
|
||||
Args:
|
||||
x: Tensor of shape (*, num_channels)
|
||||
Returns:
|
||||
Covariance matrix `cov`, of shape (num_channels, num_channels)
|
||||
Updates covariance stats self.cov
|
||||
"""
|
||||
x = x.detach()
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
x = x * (x.shape[0] ** -0.5) # avoid overflow in half precision
|
||||
return torch.matmul(x.t(), x)
|
||||
cov = torch.matmul(x.t(), x)
|
||||
self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
|
||||
self.step += 1
|
||||
|
||||
def _normalize_covar(self, cov: Tensor, eps: float) -> Tensor:
|
||||
def _update_transforms(self):
|
||||
|
||||
norm_cov, inv_sqrt_diag = self._normalize_covar(self.cov)
|
||||
|
||||
U, S, _ = norm_cov.svd()
|
||||
|
||||
if random.random() < 0.1:
|
||||
print("Decorrelate: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
|
||||
|
||||
# row indexes of U correspond to channels, column indexes correspond to
|
||||
# singular values: cov = U * diag(S) * U.t() where * is matmul.
|
||||
S_eps = S + self.eps
|
||||
S_sqrt = S_eps ** 0.5
|
||||
S_inv_sqrt = (S + self.eps) ** -0.5
|
||||
|
||||
# Transform T1, which we'll incorporate as torch.matmul(x, self.T1), is:
|
||||
# (i) multiply by inv_sqrt_diag which makes the covariance have
|
||||
# a unit diagonal.
|
||||
# (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels)
|
||||
# (iii) divide by S_sqrt, which makes all dims have unit variance.
|
||||
self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U / S_sqrt)
|
||||
|
||||
# Transform T1, which we'll incorporate as torch.matmul(x, self.TT), is:
|
||||
# (i) multiply by S_sqrt, which restors the variance of different dims,
|
||||
# (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels)
|
||||
# (iii) divide by inv_sqrt_diag which makes the covariance have its original
|
||||
# diagonal values.
|
||||
self.T2[:] = (S_sqrt.unsqueeze(-1) * U.t() / inv_sqrt_diag)
|
||||
|
||||
|
||||
# OK, now get rand_scales, which are values between 0 and self.dropout_rate; it says
|
||||
# how much randomness will be in different eigenvalues of norm_cov.
|
||||
# Basically, we want more randomness in directions with eigenvalues more than one,
|
||||
# and none in those with eigenvalues less than one.
|
||||
rand_proportion = (S - 1.0).clamp(min=0.0, max=1.0) * self.dropout_rate
|
||||
|
||||
# rand_proportion is viewed as representing a proportion of the covariance, since
|
||||
# the random and nonrandom components will not be correlated.
|
||||
self.rand_scales = rand_proportion.sqrt()
|
||||
self.nonrand_scales = (1.0 - rand_proportion).sqrt()
|
||||
|
||||
|
||||
if True:
|
||||
d = torch.matmul(self.T1, self.T2) - torch.eye(self.T1.shape[0],
|
||||
device=self.T1.device,
|
||||
dtype=self.T1.dtype)
|
||||
assert torch.all(d.abs() < 0.01)
|
||||
|
||||
|
||||
|
||||
def _normalize_covar(self, cov: Tensor) -> Tensor:
|
||||
"""
|
||||
Normlizes a covariance matrix so that its diagonal is 1, by multiplying by
|
||||
its diagonal**-0.5 on both sides.
|
||||
Args:
|
||||
cov: matrix to normalize
|
||||
eps: floating point value >0, used to prevent division by zero.
|
||||
Returns normalized_cov, inv_sqrt_diag
|
||||
"""
|
||||
diag = cov.diag()
|
||||
inv_sqrt_diag = (diag + eps) ** -0.5
|
||||
inv_sqrt_diag = (diag + self.eps) ** -0.5
|
||||
cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
|
||||
assert torch.all((cov.diag() - 1.0).abs() < 0.1) # TODO: remove
|
||||
return cov, inv_sqrt_diag
|
||||
|
||||
|
||||
@ -818,46 +877,33 @@ class Decorrelate(torch.nn.Module):
|
||||
return x
|
||||
else:
|
||||
x = x.transpose(self.channel_dim, -1) # (..., num_channels)
|
||||
x_bypass = x # will be used for "+ I"
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
cov = self._get_covar(x)
|
||||
cov, inv_sqrt_diag = self._normalize_covar(cov, self.eps)
|
||||
avg_squared_eig = (cov**2).sum(dim=0).mean()
|
||||
if random.random() < 0.001 or __name__ == "__main__":
|
||||
logging.info(f"Decorrelate: avg_squared_eig = {avg_squared_eig}")
|
||||
self._update_covar_stats(x)
|
||||
if self.step % 50 == 0 or __name__ == "__main__":
|
||||
self._update_transforms()
|
||||
|
||||
# the odd-looking formula below was obtained empirically, to match
|
||||
# the self-product and cross-correlation statistics of dropout
|
||||
x = torch.matmul(x, self.T1)
|
||||
|
||||
x = x * inv_sqrt_diag
|
||||
x_bypass = x
|
||||
|
||||
rand_scale1 = ((self.max_dropout_rate / (1.0 - self.max_dropout_rate)) ** 0.5) / avg_squared_eig
|
||||
rand_scale2 = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5)
|
||||
rand_scale = torch.minimum(rand_scale1, torch.tensor(rand_scale2, device=x.device))
|
||||
if True:
|
||||
# This block, in effect, multiplies x by a random orthogonal matrix,
|
||||
# giving us random noise.
|
||||
perm = self._randperm_like(x)
|
||||
x = torch.gather(x, -1, perm)
|
||||
# self.U will act like a different matrix for every row of x,
|
||||
# because of the random permutation.
|
||||
x = torch.matmul(x, self.U)
|
||||
x_next = torch.empty_like(x)
|
||||
# scatter_ uses perm in opposite way
|
||||
# from gather, inverting it.
|
||||
x_next.scatter_(-1, perm, x)
|
||||
x = x_next
|
||||
|
||||
# by multiplying by `cov`, then randomizing the sign of elements, then
|
||||
# multiplying by `cov` again, we are generating something that has
|
||||
# more noise in directions corresponding to larger eigenvalues of `cov`.
|
||||
# (Actually we scale by the square of the eigenvalue, which is not very
|
||||
# desirable, but was easy to implement in a fast way
|
||||
x = torch.matmul(x * rand_scale, cov)
|
||||
x = (x * self.rand_scales) + (x_bypass * self.nonrand_scales)
|
||||
|
||||
perm = self._randperm_like(x)
|
||||
x = torch.gather(x, -1, perm)
|
||||
# self.U will act like a different matrix for every row of x,
|
||||
# because of the random permutation.
|
||||
x = torch.matmul(x, self.U)
|
||||
x_next = torch.empty_like(x)
|
||||
# scatter_ uses perm in opposite way
|
||||
# from gather, inverting it.
|
||||
x_next.scatter_(-1, perm, x)
|
||||
x = x_next
|
||||
|
||||
x = torch.matmul(x, cov)
|
||||
x = x / inv_sqrt_diag
|
||||
|
||||
x = x + x_bypass
|
||||
x = torch.matmul(x, self.T2)
|
||||
x = x.transpose(self.channel_dim, -1)
|
||||
return x
|
||||
|
||||
@ -938,12 +984,13 @@ def _test_double_swish_deriv():
|
||||
|
||||
|
||||
def _test_gauss_proj_drop():
|
||||
x = torch.randn(30000, 384)
|
||||
D = 384
|
||||
x = torch.randn(30000, D)
|
||||
|
||||
|
||||
for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
|
||||
m1 = torch.nn.Dropout(dropout_rate)
|
||||
m2 = GaussProjDrop(384, dropout_rate)
|
||||
m2 = GaussProjDrop(D, dropout_rate)
|
||||
for mode in ['train', 'eval']:
|
||||
y1 = m1(x)
|
||||
y2 = m2(x)
|
||||
@ -958,12 +1005,17 @@ def _test_gauss_proj_drop():
|
||||
|
||||
def _test_decorrelate():
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
x = torch.randn(30000, 384)
|
||||
D = 384
|
||||
x = torch.randn(30000, D)
|
||||
|
||||
# give it a non-unit covariance.
|
||||
m = torch.randn(D, D)
|
||||
x = torch.matmul(x, m)
|
||||
|
||||
|
||||
for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
|
||||
m1 = torch.nn.Dropout(dropout_rate)
|
||||
m2 = Decorrelate(384, apply_prob=1.0, dropout_rate=dropout_rate, max_dropout_rate=dropout_rate)
|
||||
m2 = Decorrelate(D, apply_prob=1.0, dropout_rate=dropout_rate)
|
||||
for mode in ['train', 'eval']:
|
||||
y1 = m1(x)
|
||||
y2 = m2(x)
|
||||
@ -972,7 +1024,7 @@ def _test_decorrelate():
|
||||
cross1 = (x*y1).mean()
|
||||
y2mag = (y2*y2).mean()
|
||||
cross2 = (x*y2).mean()
|
||||
print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}")
|
||||
print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}, ratio1={y1mag/cross1}, ratio2={y2mag/cross2}")
|
||||
m1.eval()
|
||||
m2.eval()
|
||||
|
||||
|
@ -199,7 +199,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.decorrelate = Decorrelate(d_model, apply_prob=0.25)
|
||||
self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.2)
|
||||
|
||||
|
||||
def forward(
|
||||
|
Loading…
x
Reference in New Issue
Block a user