Mirror of https://github.com/k2-fsa/icefall.git
Add Gaussian version of decorrelation
Commit: a270973b69 (parent: 5d24489752)
```diff
@@ -733,23 +733,29 @@ class Decorrelate(torch.nn.Module):
            statistics match those of dropout with the same `dropout_rate`
            (assuming we applied the transform, e.g. if apply_prob == 1.0).
            This number applies when the features are un-correlated.
-       dropout_max_rate: This is an upper limit, for safety, on how aggressive the
+       max_dropout_rate: This is an upper limit, for safety, on how aggressive the
            randomization can be.
        eps: An epsilon used to prevent division by zero.
     """
     def __init__(self,
+                 num_channels: int,
                  apply_prob: float = 0.25,
                  dropout_rate: float = 0.01,
-                 dropout_max_rate: float = 0.1,
+                 max_dropout_rate: float = 0.1,
                  eps: float = 1.0e-04,
                  channel_dim: int = -1):
         super(Decorrelate, self).__init__()
         self.apply_prob = apply_prob
         self.dropout_rate = dropout_rate
-        self.dropout_max_rate = dropout_max_rate
+        self.max_dropout_rate = max_dropout_rate
         self.channel_dim = channel_dim
         self.eps = eps

+        rand_mat = torch.randn(num_channels, num_channels)
+        U, _, _ = rand_mat.svd()
+        self.register_buffer('U', U)  # a random orthogonal square matrix; will be a buffer

     def _get_covar(self, x: Tensor) -> Tensor:
         """
         Returns the uncentered covariance matrix associated with feature matrix x, detached
```
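The constructor now builds a fixed random orthogonal matrix by taking the `U` factor from the SVD of a square Gaussian random matrix. A minimal sketch (my illustration, not part of the commit) checking the orthogonality property this relies on:

```python
import torch

num_channels = 8  # small size for illustration; the test below uses 384
rand_mat = torch.randn(num_channels, num_channels)
U, _, _ = rand_mat.svd()  # Tensor.svd() as in the diff; torch.linalg.svd() is the newer API

# U should be orthogonal: U @ U.T == I, up to floating-point error.
eye = torch.eye(num_channels)
assert torch.allclose(U @ U.t(), eye, atol=1e-5)
print("max deviation from orthogonality:", (U @ U.t() - eye).abs().max().item())
```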
```diff
@@ -778,6 +784,34 @@ class Decorrelate(torch.nn.Module):
         return cov


+    def _randperm_like(self, x: Tensor):
+        """
+        Returns random permutations of the integers [0,1,..x.shape[-1]-1],
+        with the same shape as x.  All dimensions of x other than the last dimension
+        will be treated as batch dimensions.
+
+        Torch's randperm does not support a batch dimension, so we pseudo-randomly simulate it.
+
+        For now, requires x.shape[-1] to be either a power of 2 or 3 times a power of 2, as
+        we normally set channel dims.  This is required for the number-theoretic trick below
+        to yield valid permutations.
+        """
+        n = x.shape[-1]
+        assert n & (n - 1) == 0 or (n // 3) & (n // 3 - 1) == 0
+        b = x.numel() // n
+        randint = random.randint(0, 1000)
+        perm = torch.randperm(n, device=x.device)
+        # ensure all elements of batch_rand are coprime to n; this will ensure
+        # that multiplying the permutation by batch_rand and taking modulo
+        # n leaves us with permutations.
+        batch_rand = torch.arange(b, device=x.device) * (randint * 6) + 1
+        batch_rand = batch_rand.unsqueeze(-1)
+        ans = (perm * batch_rand) % n
+        ans = ans.reshape(x.shape)
+        return ans
+
     def forward(self, x: Tensor) -> Tensor:
         if not self.training or random.random() > self.apply_prob:
             return x
```
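The new `_randperm_like` leans on a number-theoretic fact: if `c` is coprime to `n`, then mapping a permutation `p` to `(c * p) % n` yields another permutation, and every value of the form `6*k + 1` is odd and not divisible by 3, hence coprime to any `n` that is a power of 2 or 3 times a power of 2. A standalone sketch (my illustration, not committed code) verifying this:

```python
import random
import torch

for n in [4, 8, 12, 16, 384]:  # powers of 2 and 3 times powers of 2
    assert n & (n - 1) == 0 or (n // 3) & (n // 3 - 1) == 0
    b = 5  # stand-in batch size
    randint = random.randint(0, 1000)
    perm = torch.randperm(n)
    batch_rand = torch.arange(b) * (randint * 6) + 1  # all of the form 6*k + 1, coprime to n
    ans = (perm * batch_rand.unsqueeze(-1)) % n       # shape (b, n)
    for row in ans:
        # each row must still contain every integer in [0, n) exactly once
        assert row.sort().values.equal(torch.arange(n))
print("all rows are valid permutations")
```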
```diff
@@ -795,22 +829,28 @@ class Decorrelate(torch.nn.Module):
         # the odd-looking formula below was obtained empirically, to match
         # the self-product and cross-correlation statistics of dropout

-        rand_scale1 = ((self.dropout_max_rate / (1.0 - self.dropout_max_rate)) ** 0.5) / avg_squared_eig
+        rand_scale1 = ((self.max_dropout_rate / (1.0 - self.max_dropout_rate)) ** 0.5) / avg_squared_eig
         rand_scale2 = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5)
         rand_scale = torch.minimum(rand_scale1, torch.tensor(rand_scale2, device=x.device))

         # by multiplying by `cov`, then randomizing the sign of elements, then
         # multiplying by `cov` again, we are generating something that has
         # more noise in directions corresponding to larger eigenvalues of `cov`.
         # (Actually we scale by the square of the eigenvalue, which is not very
         # desirable, but was easy to implement in a fast way.)
         x = torch.matmul(x * rand_scale, cov)
         rand_mask = (torch.rand_like(x) > 0.5)
         # randomize the sign of elements of x.
         # important to write the expression this way, so that only rand_mask needs
         # to be stored for backprop.
         x = x - 2 * (rand_mask * x)

+        perm = self._randperm_like(x)
+        x = torch.gather(x, -1, perm)
+        # self.U will act like a different matrix for every row of x,
+        # because of the random permutation.
+        x = torch.matmul(x, self.U)
+        x_next = torch.empty_like(x)
+        # scatter_ uses perm in the opposite way
+        # from gather, inverting it.
+        x_next.scatter_(-1, perm, x)
+        x = x_next
+
         x = torch.matmul(x, cov)
         x = x + x_bypass
         x = x.transpose(self.channel_dim, -1)
```
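In `forward`, the `gather`/`scatter_` pair applies a per-row permutation and its inverse around the multiplication by `self.U`, so each row effectively sees a differently permuted copy of the same orthogonal matrix. A small sketch of the round trip (my illustration; it uses toy sizes, independent `randperm`s instead of `_randperm_like`, and `torch.linalg.qr` for an arbitrary orthogonal matrix):

```python
import torch

b, n = 4, 8
x = torch.randn(b, n)
U = torch.linalg.qr(torch.randn(n, n)).Q                   # any orthogonal matrix
perm = torch.stack([torch.randperm(n) for _ in range(b)])  # per-row permutations

y = torch.gather(x, -1, perm)  # y[i, j] = x[i, perm[i, j]]
y = torch.matmul(y, U)
out = torch.empty_like(y)
out.scatter_(-1, perm, y)      # out[i, perm[i, j]] = y[i, j]: inverts the gather

# check that scatter_ exactly inverts gather when nothing happens in between:
z = torch.gather(x, -1, perm)
x_back = torch.empty_like(z)
x_back.scatter_(-1, perm, z)
assert torch.equal(x_back, x)
```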
```diff
@@ -918,7 +958,7 @@ def _test_decorrelate():

     for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
         m1 = torch.nn.Dropout(dropout_rate)
-        m2 = Decorrelate(apply_prob=1.0, dropout_rate=dropout_rate)
+        m2 = Decorrelate(384, apply_prob=1.0, dropout_rate=dropout_rate, max_dropout_rate=dropout_rate)
         for mode in ['train', 'eval']:
             y1 = m1(x)
             y2 = m2(x)
```
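The test compares `Decorrelate` against `torch.nn.Dropout` with matched rates. The `sqrt(rate / (1 - rate))` factors in `forward` come from dropout's noise statistics: with keep-probability `1 - p` and `1/(1 - p)` rescaling, the per-element noise standard deviation is `|x| * sqrt(p / (1 - p))`. A quick numerical check (my sketch, not part of the diff):

```python
import torch

torch.manual_seed(0)
p = 0.1
x = torch.ones(1000, 384)
m = torch.nn.Dropout(p)       # freshly constructed modules are in training mode
noise = m(x) - x              # the noise dropout adds around its mean (which is x)
empirical = noise.pow(2).mean().sqrt().item()
predicted = (p / (1.0 - p)) ** 0.5  # = 0.333... for p = 0.1
print(f"empirical noise std {empirical:.4f} vs predicted {predicted:.4f}")
```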
```diff
@@ -199,7 +199,7 @@ class ConformerEncoderLayer(nn.Module):
         )

         self.dropout = torch.nn.Dropout(dropout)
-        self.decorrelate = Decorrelate(apply_prob=0.25)
+        self.decorrelate = Decorrelate(d_model, apply_prob=0.25)


     def forward(
```
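Since `Decorrelate` now takes `num_channels` as its first argument, callers such as `ConformerEncoderLayer` pass the model dimension. A hypothetical usage sketch, assuming the `Decorrelate` class from the diff above is in scope:

```python
import torch

d_model = 384  # matches the channel count used in _test_decorrelate
decorrelate = Decorrelate(d_model, apply_prob=0.25)  # assumes Decorrelate (above) is imported

x = torch.randn(10, 2, d_model)  # (seq, batch, channel); channel_dim defaults to -1
decorrelate.train()
y = decorrelate(x)  # randomized with probability apply_prob during training
decorrelate.eval()
assert torch.equal(decorrelate(x), x)  # identity when not training
```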