Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-09 09:04:19 +00:00)

Commit a6050cb2de (parent 75c822c7e9): Implement new, more principled but maybe slower version.
@@ -736,52 +736,111 @@ class Decorrelate(torch.nn.Module):
        max_dropout_rate: This is an upper limit, for safety, on how aggressive the
               randomization can be.
        eps: An epsilon used to prevent division by zero.
+       beta: A value 0 < beta < 1 that controls decay of covariance stats
     """
     def __init__(self,
                  num_channels: int,
                  apply_prob: float = 0.25,
-                 dropout_rate: float = 0.01,
-                 max_dropout_rate: float = 0.1,
+                 dropout_rate: float = 0.1,
                  eps: float = 1.0e-04,
+                 beta: float = 0.95,
                  channel_dim: int = -1):
         super(Decorrelate, self).__init__()
         self.apply_prob = apply_prob
         self.dropout_rate = dropout_rate
-        self.max_dropout_rate = max_dropout_rate
         self.channel_dim = channel_dim
         self.eps = eps
+        self.beta = beta

         rand_mat = torch.randn(num_channels, num_channels)
         U, _, _ = rand_mat.svd()

         self.register_buffer('U', U)  # a random orthogonal square matrix.  will be a buffer.

+        self.register_buffer('T1', torch.eye(num_channels))
+        self.register_buffer('rand_scales', torch.zeros(num_channels))
+        self.register_buffer('nonrand_scales', torch.ones(num_channels))
+        self.register_buffer('T2', torch.eye(num_channels))
+        self.register_buffer('cov', torch.zeros(num_channels, num_channels))
+        self.step = 0

-    def _get_covar(self, x: Tensor) -> Tensor:
+    def _update_covar_stats(self, x: Tensor) -> None:
         """
-        Returns the uncentered covariance matrix associated with feature matrix x, detached
-        from its input.
         Args:
            x: Tensor of shape (*, num_channels)
-        Returns:
-           Covariance matrix `cov`, of shape (num_channels, num_channels)
+        Updates covariance stats self.cov
         """
         x = x.detach()
         x = x.reshape(-1, x.shape[-1])
         x = x * (x.shape[0] ** -0.5)  # avoid overflow in half precision
-        return torch.matmul(x.t(), x)
+        cov = torch.matmul(x.t(), x)
+        self.cov.mul_(self.beta).add_(cov, alpha=(1 - self.beta))
+        self.step += 1

-    def _normalize_covar(self, cov: Tensor, eps: float) -> Tensor:
+    def _update_transforms(self):
+        norm_cov, inv_sqrt_diag = self._normalize_covar(self.cov)
+        U, S, _ = norm_cov.svd()
+        if random.random() < 0.1:
+            print(f"Decorrelate: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
+
+        # Row indexes of U correspond to channels, column indexes correspond to
+        # singular values: cov = U * diag(S) * U.t(), where * is matmul.
+        S_eps = S + self.eps
+        S_sqrt = S_eps ** 0.5
+        S_inv_sqrt = (S + self.eps) ** -0.5
+
+        # Transform T1, which we'll incorporate as torch.matmul(x, self.T1), does:
+        # (i) multiply by inv_sqrt_diag, which makes the covariance have a unit diagonal,
+        # (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels),
+        # (iii) divide by S_sqrt, which makes all dims have unit variance.
+        self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U / S_sqrt)
+
+        # Transform T2, which we'll incorporate as torch.matmul(x, self.T2), does:
+        # (i) multiply by S_sqrt, which restores the variance of the different dims,
+        # (ii) multiply by U.t(), which undoes the diagonalizing rotation,
+        # (iii) divide by inv_sqrt_diag, which restores the covariance's original
+        #       diagonal values.
+        self.T2[:] = (S_sqrt.unsqueeze(-1) * U.t() / inv_sqrt_diag)
+
+        # Now get rand_scales, which holds values between 0 and self.dropout_rate; it says
+        # how much randomness will be in the different eigenvalues of norm_cov.
+        # Basically, we want more randomness in directions with eigenvalues more than one,
+        # and none in those with eigenvalues less than one.
+        rand_proportion = (S - 1.0).clamp(min=0.0, max=1.0) * self.dropout_rate
+
+        # rand_proportion is viewed as representing a proportion of the covariance, since
+        # the random and nonrandom components will not be correlated.
+        self.rand_scales = rand_proportion.sqrt()
+        self.nonrand_scales = (1.0 - rand_proportion).sqrt()
+
+        if True:
+            d = torch.matmul(self.T1, self.T2) - torch.eye(self.T1.shape[0],
+                                                           device=self.T1.device,
+                                                           dtype=self.T1.dtype)
+            assert torch.all(d.abs() < 0.01)
+
+    def _normalize_covar(self, cov: Tensor) -> Tensor:
         """
         Normalizes a covariance matrix so that its diagonal is 1, by multiplying by
         its diagonal**-0.5 on both sides.
         Args:
            cov: matrix to normalize
-           eps: floating point value >0, used to prevent division by zero.
         Returns normalized_cov, inv_sqrt_diag
         """
         diag = cov.diag()
-        inv_sqrt_diag = (diag + eps) ** -0.5
+        inv_sqrt_diag = (diag + self.eps) ** -0.5
         cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
+        assert torch.all((cov.diag() - 1.0).abs() < 0.1)  # TODO: remove
         return cov, inv_sqrt_diag
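Note on the T1/T2 pair built in _update_transforms: T1 whitens the features and T2 is its (approximate) inverse. The standalone sketch below is not part of the commit; the helper name make_transforms and the toy shapes are mine. It mirrors the construction above and checks the two properties the code relies on: x @ T1 has approximately unit covariance, and T1 @ T2 is close to the identity (the latter is what the assertion at the end of _update_transforms verifies).

import torch

def make_transforms(cov: torch.Tensor, eps: float = 1.0e-04):
    # Normalize cov to a unit diagonal, as in _normalize_covar.
    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
    # Diagonalize: norm_cov = U @ diag(S) @ U.t(), as in _update_transforms.
    U, S, _ = norm_cov.svd()
    S_sqrt = (S + eps) ** 0.5
    # T1 whitens (unit diagonal -> uncorrelated channels -> unit variance);
    # T2 applies the same steps in reverse, so T1 @ T2 ~= I.
    T1 = inv_sqrt_diag.unsqueeze(-1) * U / S_sqrt
    T2 = S_sqrt.unsqueeze(-1) * U.t() / inv_sqrt_diag
    return T1, T2

num_channels = 8
x = torch.randn(5000, num_channels) @ torch.randn(num_channels, num_channels)
cov = torch.matmul(x.t(), x) / x.shape[0]
T1, T2 = make_transforms(cov)
y = torch.matmul(x, T1)
print(torch.matmul(y.t(), y) / y.shape[0])                            # ~ identity matrix
print((torch.matmul(T1, T2) - torch.eye(num_channels)).abs().max())   # ~ 0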
@@ -818,31 +877,19 @@ class Decorrelate(torch.nn.Module):
             return x
         else:
             x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
-            x_bypass = x  # will be used for "+ I"

             with torch.cuda.amp.autocast(enabled=False):
-                cov = self._get_covar(x)
-                cov, inv_sqrt_diag = self._normalize_covar(cov, self.eps)
-                avg_squared_eig = (cov**2).sum(dim=0).mean()
-                if random.random() < 0.001 or __name__ == "__main__":
-                    logging.info(f"Decorrelate: avg_squared_eig = {avg_squared_eig}")
+                self._update_covar_stats(x)
+                if self.step % 50 == 0 or __name__ == "__main__":
+                    self._update_transforms()

-                # the odd-looking formula below was obtained empirically, to match
-                # the self-product and cross-correlation statistics of dropout
-                x = x * inv_sqrt_diag
-                rand_scale1 = ((self.max_dropout_rate / (1.0 - self.max_dropout_rate)) ** 0.5) / avg_squared_eig
-                rand_scale2 = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5)
-                rand_scale = torch.minimum(rand_scale1, torch.tensor(rand_scale2, device=x.device))
-
-                # by multiplying by `cov`, then randomizing the sign of elements, then
-                # multiplying by `cov` again, we are generating something that has
-                # more noise in directions corresponding to larger eigenvalues of `cov`.
-                # (Actually we scale by the square of the eigenvalue, which is not very
-                # desirable, but was easy to implement in a fast way.)
-                x = torch.matmul(x * rand_scale, cov)
+                x = torch.matmul(x, self.T1)
+                x_bypass = x

+                if True:
+                    # This block, in effect, multiplies x by a random orthogonal matrix,
+                    # giving us random noise.
                     perm = self._randperm_like(x)
                     x = torch.gather(x, -1, perm)
                     # self.U will act like a different matrix for every row of x,
@@ -854,10 +901,9 @@ class Decorrelate(torch.nn.Module):
                     x_next.scatter_(-1, perm, x)
                     x = x_next

-                x = torch.matmul(x, cov)
-                x = x / inv_sqrt_diag
+                x = (x * self.rand_scales) + (x_bypass * self.nonrand_scales)

-                x = x + x_bypass
+                x = torch.matmul(x, self.T2)
             x = x.transpose(self.channel_dim, -1)
             return x
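So the new forward path is: whiten with T1, build noise by permuting each row and rotating it with the fixed orthogonal U, mix noise and signal with per-direction weights whose squares sum to one, then map back with T2. Because the noise is roughly uncorrelated with the whitened signal, this mixing approximately preserves per-channel variance. A toy check of just the mixing step, not part of the commit; independent Gaussian noise and the 0.2 factor stand in for the permute-and-rotate noise and for (S - 1).clamp(0, 1) * dropout_rate:

import torch

num_channels = 256
x = torch.randn(20000, num_channels)               # stands in for the whitened x @ T1
rand_proportion = torch.rand(num_channels) * 0.2   # stands in for (S - 1).clamp(0, 1) * dropout_rate
rand_scales = rand_proportion.sqrt()
nonrand_scales = (1.0 - rand_proportion).sqrt()

noise = torch.randn_like(x)                        # stand-in for the permuted/rotated copy of x

y = (noise * rand_scales) + (x * nonrand_scales)
print((x * x).mean(dim=0).mean())    # ~ 1.0
print((y * y).mean(dim=0).mean())    # ~ 1.0: variance preserved, since rand^2 + nonrand^2 = 1
print((x * y).mean(dim=0).mean())    # ~ nonrand_scales.mean(): correlation with the input is controlled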
@@ -938,12 +984,13 @@ def _test_double_swish_deriv():


 def _test_gauss_proj_drop():
-    x = torch.randn(30000, 384)
+    D = 384
+    x = torch.randn(30000, D)

     for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
         m1 = torch.nn.Dropout(dropout_rate)
-        m2 = GaussProjDrop(384, dropout_rate)
+        m2 = GaussProjDrop(D, dropout_rate)
         for mode in ['train', 'eval']:
             y1 = m1(x)
             y2 = m2(x)
@@ -958,12 +1005,17 @@ def _test_gauss_proj_drop():

 def _test_decorrelate():
     logging.getLogger().setLevel(logging.INFO)
-    x = torch.randn(30000, 384)
+    D = 384
+    x = torch.randn(30000, D)
+
+    # give it a non-unit covariance.
+    m = torch.randn(D, D)
+    x = torch.matmul(x, m)

     for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
         m1 = torch.nn.Dropout(dropout_rate)
-        m2 = Decorrelate(384, apply_prob=1.0, dropout_rate=dropout_rate, max_dropout_rate=dropout_rate)
+        m2 = Decorrelate(D, apply_prob=1.0, dropout_rate=dropout_rate)
         for mode in ['train', 'eval']:
             y1 = m1(x)
             y2 = m2(x)
@@ -972,7 +1024,7 @@ def _test_decorrelate():
             cross1 = (x*y1).mean()
             y2mag = (y2*y2).mean()
             cross2 = (x*y2).mean()
-            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}")
+            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}, ratio1={y1mag/cross1}, ratio2={y2mag/cross2}")
             m1.eval()
             m2.eval()
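For reference on the new ratio columns: plain torch.nn.Dropout(p) zeroes elements and scales survivors by 1/(1 - p), so E[x*y] = E[x^2] while E[y^2] = E[x^2]/(1 - p), and in training mode ratio1 should sit near 1/(1 - p). That gives a baseline to compare ratio2 (the Decorrelate output) against. A quick standalone check, illustrative only and not part of the commit:

import torch

x = torch.randn(30000, 384)
for p in [0.2, 0.1, 0.05, 0.01]:
    y = torch.nn.Dropout(p)(x)
    ratio = (y * y).mean() / (x * y).mean()
    print(f"p={p}: ratio={ratio.item():.3f}, 1/(1-p)={1.0 / (1.0 - p):.3f}")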
@@ -199,7 +199,7 @@ class ConformerEncoderLayer(nn.Module):
         )

         self.dropout = torch.nn.Dropout(dropout)
-        self.decorrelate = Decorrelate(d_model, apply_prob=0.25)
+        self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.2)

     def forward(
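The conformer change only touches the constructor call; the module itself is used like a dropout layer on activations whose last dimension is d_model (channel_dim defaults to -1). A hypothetical usage sketch, not the actual icefall wiring; the import path and the shape of src are assumptions:

import torch
from scaling import Decorrelate  # assuming scaling.py, where Decorrelate above lives, is importable

d_model = 512
decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.2)  # as constructed above
src = torch.randn(100, 8, d_model)   # (seq_len, batch, d_model), an assumed activation shape
out = decorrelate(src)               # expected to behave like dropout: pass-through when not applied
print(out.shape)                     # same shape as src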