diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 87e3cbd18..56914344e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -733,23 +733,29 @@ class Decorrelate(torch.nn.Module):
             statistics match those of dropout with the same `dropout_rate` (assuming
             we applied the transform, e.g. if apply_prob == 1.0).
             This number applies when the features are un-correlated.
-       dropout_max_rate: This is an upper limit, for safety, on how aggressive the
+       max_dropout_rate: This is an upper limit, for safety, on how aggressive the
             randomization can be.
        eps: An epsilon used to prevent division by zero.
     """
     def __init__(self,
+                 num_channels: int,
                  apply_prob: float = 0.25,
                  dropout_rate: float = 0.01,
-                 dropout_max_rate: float = 0.1,
+                 max_dropout_rate: float = 0.1,
                  eps: float = 1.0e-04,
                  channel_dim: int = -1):
         super(Decorrelate, self).__init__()
         self.apply_prob = apply_prob
         self.dropout_rate = dropout_rate
-        self.dropout_max_rate = dropout_max_rate
+        self.max_dropout_rate = max_dropout_rate
         self.channel_dim = channel_dim
         self.eps = eps
 
+        rand_mat = torch.randn(num_channels, num_channels)
+        U, _, _ = rand_mat.svd()
+        self.register_buffer('U', U)  # a random orthogonal square matrix, stored as a buffer
+
+
     def _get_covar(self, x: Tensor) -> Tensor:
         """
         Returns the uncentered covariance matrix associated with feature matrix x, detached
@@ -778,6 +784,34 @@ class Decorrelate(torch.nn.Module):
 
         return cov
 
+    def _randperm_like(self, x: Tensor):
+        """
+        Returns random permutations of the integers [0,1,..x.shape[-1]-1],
+        with the same shape as x.  All dimensions of x other than the last dimension
+        will be treated as batch dimensions.
+
+        Torch's randperm does not support a batch dimension, so we simulate it pseudo-randomly.
+
+        For now, requires x.shape[-1] to be either a power of 2 or 3 times a power of 2 (as
+        we normally set channel dims); the number-theoretic trick below relies on this.
+        """
+        n = x.shape[-1]
+
+        assert n & (n-1) == 0 or (n//3 & (n//3 - 1)) == 0
+
+        b = x.numel() // n
+        randint = random.randint(0, 1000)
+        perm = torch.randperm(n, device=x.device)
+        # ensure all elements of batch_rand are coprime to n; this will ensure
+        # that multiplying the permutation by batch_rand and taking modulo
+        # n leaves us with permutations.
+        batch_rand = torch.arange(b, device=x.device) * (randint * 6) + 1
+        batch_rand = batch_rand.unsqueeze(-1)
+        ans = (perm * batch_rand) % n
+        ans = ans.reshape(x.shape)
+        return ans
+
+
     def forward(self, x: Tensor) -> Tensor:
         if not self.training or random.random() > self.apply_prob:
             return x
@@ -795,22 +829,28 @@ class Decorrelate(torch.nn.Module):
 
             # the odd-looking formula below was obtained empirically, to match
             # the self-product and cross-correlation statistics of dropout
-            rand_scale1 = ((self.dropout_max_rate / (1.0 - self.dropout_max_rate)) ** 0.5) / avg_squared_eig
+            rand_scale1 = ((self.max_dropout_rate / (1.0 - self.max_dropout_rate)) ** 0.5) / avg_squared_eig
             rand_scale2 = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5)
             rand_scale = torch.minimum(rand_scale1, torch.tensor(rand_scale2, device=x.device))
-
             # by multiplying by `cov`, then randomizing the sign of elements, then
             # multiplying by `cov` again, we are generating something that has
             # more noise in directions corresponding to larger eigenvalues of `cov`.
            # (Actually we scale by the square of the eigenvalue, which is not very
            # desirable, but was easy to implement in a fast way.)
             x = torch.matmul(x * rand_scale, cov)
-            rand_mask = (torch.rand_like(x) > 0.5)
-            # randomize the sign of elements of x.
-            # important to write the expression this way, so that only rand_mask needs
-            # to be stored for backprop.
-            x = x - 2 * (rand_mask * x)
+
+            perm = self._randperm_like(x)
+            x = torch.gather(x, -1, perm)
+            # self.U will act like a different matrix for every row of x,
+            # because of the random permutation.
+            x = torch.matmul(x, self.U)
+            x_next = torch.empty_like(x)
+            # scatter_ uses perm in the opposite way from gather,
+            # inverting the permutation.
+            x_next.scatter_(-1, perm, x)
+            x = x_next
+
             x = torch.matmul(x, cov)
 
             x = x + x_bypass
             x = x.transpose(self.channel_dim, -1)
@@ -918,7 +958,7 @@ def _test_decorrelate():
 
     for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
         m1 = torch.nn.Dropout(dropout_rate)
-        m2 = Decorrelate(apply_prob=1.0, dropout_rate=dropout_rate)
+        m2 = Decorrelate(384, apply_prob=1.0, dropout_rate=dropout_rate, max_dropout_rate=dropout_rate)
         for mode in ['train', 'eval']:
             y1 = m1(x)
             y2 = m2(x)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
index 142fa34d8..00ff3f3cc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -199,7 +199,7 @@ class ConformerEncoderLayer(nn.Module):
         )
         self.dropout = torch.nn.Dropout(dropout)
 
-        self.decorrelate = Decorrelate(apply_prob=0.25)
+        self.decorrelate = Decorrelate(d_model, apply_prob=0.25)
 
     def forward(
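For reference, the patch combines two tricks that are easy to miss when reading the diff: `_randperm_like` builds a whole batch of valid permutations from a single `torch.randperm` call by multiplying it (mod n) by per-row constants coprime to n, and the fixed orthogonal buffer `U` then acts like a different orthogonal matrix on every row, because each row is permuted differently before and after the shared matmul. The sketch below is illustrative only, not the committed code: the standalone `randperm_like` helper, the batch size 8, and `n = 384` are assumptions made for the demo.

```python
# Minimal standalone sketch of the two tricks (assumes PyTorch >= 1.8);
# names mirror the patch but this is not the committed implementation.
import random
import torch

def randperm_like(x: torch.Tensor) -> torch.Tensor:
    # x: (..., n).  Returns an int64 tensor of x's shape whose last dim holds
    # a permutation of [0, n) for each row; rows differ because each row's
    # base permutation is multiplied (mod n) by a different constant.
    n = x.shape[-1]
    # the multipliers below have the form 6*k + 1, which is coprime to 2 and 3,
    # so n must be a power of 2, or 3 times a power of 2 (e.g. 256 or 384).
    assert n & (n - 1) == 0 or (n % 3 == 0 and (n // 3) & (n // 3 - 1) == 0)
    b = x.numel() // n
    perm = torch.randperm(n, device=x.device)
    batch_rand = torch.arange(b, device=x.device) * (random.randint(0, 1000) * 6) + 1
    # multiplying a permutation by a constant coprime to n (mod n) is a
    # bijection on [0, n), so every row remains a valid permutation.
    return ((perm * batch_rand.unsqueeze(-1)) % n).reshape(x.shape)

torch.manual_seed(0)
n = 384
x = torch.randn(8, n)

# check: every row of the result is a valid permutation of [0, n)
perm = randperm_like(x)
assert all(sorted(row.tolist()) == list(range(n)) for row in perm)

# a random orthogonal matrix, obtained as in Decorrelate.__init__
U, _, _ = torch.randn(n, n).svd()

# permute each row differently, rotate by the shared U, then un-permute;
# the net effect is a different orthogonal transform per row.
y = torch.gather(x, -1, perm)
y = torch.matmul(y, U)
out = torch.empty_like(y)
out.scatter_(-1, perm, y)  # scatter_ inverts the gather's permutation

# permutation -> U -> inverse permutation is orthogonal, so row norms survive
assert torch.allclose(x.norm(dim=-1), out.norm(dim=-1), atol=1e-3)
```

The payoff of the permutation trick is that one matmul with the shared `U` serves the whole batch; without it, giving each row its own random rotation would require a separate matrix per row, which would be far more expensive in memory and compute.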