diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index 840d847cb..12525873b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -1029,6 +1029,8 @@ class Conv2dSubsampling(nn.Module):
 
 
 if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     feature_dim = 50
     c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
     batch_size = 5
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 6e4884df7..09eadc9e9 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -713,114 +713,96 @@ class GaussProjDrop(torch.nn.Module):
         x = (x_next * self.rand_scale + x_bypass)
         return x
 
-
-class Decorrelate(torch.nn.Module):
+class JoinDropout(torch.nn.Module):
     """
-    This module is something similar to dropout; it is a random transformation that
-    does nothing in eval mode.
-    It is designed specifically to encourage the input data to be decorrelated, i.e.
-    to have a diagonal covariance matrix (not necessarily unity).
+    This module implements something like:
+          y = bypass + dropout(x)
+    but does it in such a way as to encourage x to vary in directions that will tend
+    to make the dimensions of y as decorrelated as possible.  We do this
+    by putting lots of dropout in directions in the space in which we
+    don't want x to vary (because it will tend to increase correlations between
+    dimensions in the output y).
 
-    To save time, in training mode we only apply it on randomly selected minibatches.
 
     Args:
      num_channels:  The number of channels, e.g. 256.
      apply_prob:  The probability with which we apply this each time, in training mode.
        This is to save time (but of course it will tend to make the effect weaker).
-     dropout_rate:  This number determines the scale of the random multiplicative
-         noise, in such a way that the self-correlation and cross-correlation
-         statistics match those dropout with the same `dropout_rate`
-         (assuming we applied the transform, e.g. if apply_prob == 1.0)
-         This number applies when the features are un-correlated.
-     max_dropout_rate:  This is an upper limit, for safety, on how aggressive the
-         randomization can be.
-     eps:   An epsilon used to prevent division by zero.
+     dropout_rate:  This number determines the average dropout probability
+        (it will actually vary across dimensions).
+     eps:  An epsilon used to prevent division by zero.
      beta: A value 0 < beta < 1 that controls decay of covariance stats
+     channel_dim: The dimension of the input corresponding to the channel, e.g.
+        -1, 0, 1, 2.
    """
    def __init__(self,
                 num_channels: int,
-                apply_prob: float = 0.25,
+                apply_prob: float = 0.75,
                 dropout_rate: float = 0.1,
                 eps: float = 1.0e-04,
                 beta: float = 0.95,
                 channel_dim: int = -1):
-        super(Decorrelate, self).__init__()
+        super(JoinDropout, self).__init__()
        self.apply_prob = apply_prob
        self.dropout_rate = dropout_rate
        self.channel_dim = channel_dim
        self.eps = eps
        self.beta = beta
 
-        #rand_mat = torch.randn(num_channels, num_channels)
-        #U, _, _ = rand_mat.svd()
-        #self.register_buffer('U', U)  # a random orthogonal square matrix.  will be a buffer.
-
        self.register_buffer('T1', torch.eye(num_channels))
-        self.register_buffer('rand_scales', torch.zeros(num_channels))
-        self.register_buffer('nonrand_scales', torch.ones(num_channels))
+        self.register_buffer('dropout_probs', torch.zeros(num_channels))
+        self.register_buffer('scales', torch.ones(num_channels))
        self.register_buffer('T2', torch.eye(num_channels))
        self.register_buffer('cov', torch.zeros(num_channels, num_channels))
        self.step = 0
 
-
-    def _update_covar_stats(self, x: Tensor) -> None:
+    def _update_covar_stats(self, y: Tensor) -> None:
        """
        Args:
-          x: Tensor of shape (*, num_channels)
+          y: Tensor of shape (*, num_channels), of output.
        Updates covariance stats self.cov
        """
-        x = x.detach()
-        x = x.reshape(-1, x.shape[-1])
-        x = x * (x.shape[0] ** -0.5) # avoid overflow in half precision
-        cov = torch.matmul(x.t(), x)
+        y = y.detach()
+        y = y.reshape(-1, y.shape[-1])
+        y = y * (y.shape[0] ** -0.5) # avoid overflow in half precision
+        cov = torch.matmul(y.t(), y)
        self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
-        self.step += 1
 
    def _update_transforms(self):
-
        norm_cov, inv_sqrt_diag = self._normalize_covar(self.cov)
-        U, S, _ = norm_cov.svd()
+        U, S, _ = norm_cov.svd()  # because diag of norm_cov is 1.0, S.mean() == 1.0
 
        if random.random() < 0.1:
-            logging.info(f"Decorrelate: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
+            logging.info(f"JoinDropout: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
+
+        dropout_probs = (S.sqrt() - 0.99).clamp(min=0)
+        dropout_probs = dropout_probs * (self.dropout_rate / dropout_probs.mean())
+        dropout_probs = dropout_probs.clamp(max=0.5)
+        self.dropout_probs[:] = dropout_probs
+        self.scales[:] = 1.0 / (1 - dropout_probs)
+
        # row indexes of U correspond to channels, column indexes correspond to
        # singular values: cov = U * diag(S) * U.t() where * is matmul.
-        S_eps = S + self.eps
-        S_sqrt = S_eps ** 0.5
-        S_inv_sqrt = (S + self.eps) ** -0.5
+
        # Transform T1, which we'll incorporate as torch.matmul(x, self.T1), is:
        # (i) multiply by inv_sqrt_diag which makes the covariance have
        #     a unit diagonal.
        # (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels)
-        # (iii) divide by S_sqrt, which makes all dims have unit variance.
-        self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U / S_sqrt)
+        self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U)
 
-        # Transform T1, which we'll incorporate as torch.matmul(x, self.TT), is:
-        # (i) multiply by S_sqrt, which restors the variance of different dims,
-        # (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels)
-        # (iii) divide by inv_sqrt_diag which makes the covariance have its original
+        # Transform T2, which we'll incorporate as torch.matmul(x, self.T2), is:
+        # (i) multiply by U, which un-diagonalizes norm_cov
+        # (ii) divide by inv_sqrt_diag which makes the covariance have its original
        #     diagonal values.
-        self.T2[:] = (S_sqrt.unsqueeze(-1) * U.t() / inv_sqrt_diag)
+        self.T2[:] = (U.t() / inv_sqrt_diag)
 
-        # OK, now get rand_scales, which are values between 0 and self.dropout_rate; it says
-        # how much randomness will be in different eigenvalues of norm_cov.
-        # Basically, we want more randomness in directions with eigenvalues more than one,
-        # and none in those with eigenvalues less than one.
-        rand_proportion = (S - 1.0).clamp(min=0.0, max=1.0) * self.dropout_rate
-
-        # rand_proportion is viewed as representing a proportion of the covariance, since
-        # the random and nonrandom components will not be correlated.
-        self.rand_scales[:] = rand_proportion.sqrt()
-        self.nonrand_scales[:] = (1.0 - rand_proportion).sqrt()
-
-
-        if True:
+        if random.random() < 0.01:
            d = torch.matmul(self.T1, self.T2) - torch.eye(self.T1.shape[0], device=self.T1.device, dtype=self.T1.dtype)
@@ -839,75 +821,38 @@ class Decorrelate(torch.nn.Module):
        diag = cov.diag()
        inv_sqrt_diag = (diag + self.eps) ** -0.5
        cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
-        assert torch.all((cov.diag() - 1.0).abs() < 0.1)  # TODO: remove
        return cov, inv_sqrt_diag
 
-    def _randperm_like(self, x: Tensor):
-        """
-        Returns random permutations of the integers [0,1,..x.shape[-1]-1],
-        with the same shape as x.  All dimensions of x other than the last dimension
-        will be treated as batch dimensions.
-        Torch's randperm does not support a batch dimension, so we pseudo-randomly simulate it.
-
-        For now, requires x.shape[-1] to be either a power of 2 or 3 times a power of 2, as
-        we normally set channel dims.  This is required for some number theoretic stuff.
-        """
-        n = x.shape[-1]
-
-        assert n & (n-1) == 0 or (n//3 & (n//3 - 1)) == 0
-
-        b = x.numel() // n
-        randint = random.randint(0, 1000)
-        perm = torch.randperm(n, device=x.device)
-        # ensure all elements of batch_rand are coprime to n; this will ensure
-        # that multiplying the permutation by batch_rand and taking modulo
-        # n leaves us with permutations.
-        batch_rand = torch.arange(b, device=x.device) * (randint * 6) + 1
-        batch_rand = batch_rand.unsqueeze(-1)
-        ans = (perm * batch_rand) % n
-        ans = ans.reshape(x.shape)
-        return ans
-
-
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, bypass: Tensor, x: Tensor) -> Tensor:
        if not self.training or random.random() > self.apply_prob:
-            return x
+            return bypass + x
        else:
            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
+            bypass = bypass.transpose(self.channel_dim, -1)
+
+            x = torch.matmul(x, self.T1.clone())
+
+            mask = (torch.rand_like(x) > self.dropout_probs)
+            x = (x * mask) * self.scales.clone()
+            x = torch.matmul(x, self.T2.clone())
+
+            y = bypass + x
+            self.step += 1
            with torch.cuda.amp.autocast(enabled=False):
-                self._update_covar_stats(x)
-                if self.step % 50 == 0 or __name__ == "__main__":
-                    self._update_transforms()
+                if self.step % 4 == 0 or __name__ == "__main__":
+                    self._update_covar_stats(y)
+                if self.step % 40 == 0 or __name__ == "__main__":
+                    # note: important that 40 is a multiple of 4
+                    self._update_transforms()
 
-            x = torch.matmul(x, self.T1)
+            y = y.transpose(self.channel_dim, -1)
+            return y
 
-            x_bypass = x
-            if False:
-                # This block, in effect, multiplies x by a random orthogonal matrix,
-                # giving us random noise.
-                perm = self._randperm_like(x)
-                x = torch.gather(x, -1, perm)
-                # self.U will act like a different matrix for every row of x,
-                # because of the random permutation.
-                x = torch.matmul(x, self.U)
-                x_next = torch.empty_like(x)
-                # scatter_ uses perm in opposite way
-                # from gather, inverting it.
-                x_next.scatter_(-1, perm, x)
-                x = x_next
-            mask = (torch.rand_like(x) > 0.5)
-            x = x - (x * mask) * 2
-            x = (x * self.rand_scales) + (x_bypass * self.nonrand_scales)
-
-            x = torch.matmul(x, self.T2)
-            x = x.transpose(self.channel_dim, -1)
-            return x
@@ -1005,8 +950,7 @@ def _test_gauss_proj_drop():
    m1.eval()
    m2.eval()
 
 
-def _test_decorrelate():
-    logging.getLogger().setLevel(logging.INFO)
+def _test_join_dropout():
    D = 384
    x = torch.randn(30000, D)
@@ -1014,13 +958,14 @@
    m = torch.randn(D, D)
    x = torch.matmul(x, m)
-
    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
        m1 = torch.nn.Dropout(dropout_rate)
-        m2 = Decorrelate(D, apply_prob=1.0, dropout_rate=dropout_rate)
+        m2 = JoinDropout(D, apply_prob=1.0, dropout_rate=dropout_rate)
+        bypass = torch.zeros_like(x)
        for mode in ['train', 'eval']:
            y1 = m1(x)
-            y2 = m2(x)
+            for _ in range(2):
+                y2 = m2(bypass, x)
            xmag = (x*x).mean()
            y1mag = (y1*y1).mean()
            cross1 = (x*y1).mean()
@@ -1032,7 +977,10 @@
 
 
 if __name__ == "__main__":
-    _test_decorrelate()
+    logging.getLogger().setLevel(logging.INFO)
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    _test_join_dropout()
    _test_gauss_proj_drop()
    _test_activation_balancer_sign()
    _test_activation_balancer_magnitude()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
index b64b9a6bc..7440fe05a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -29,7 +29,7 @@ from scaling import (
    ScaledConv1d,
    ScaledConv2d,
    ScaledLinear,
-    Decorrelate,
+    JoinDropout,
 )
 from torch import Tensor, nn
 
@@ -198,8 +198,10 @@ class ConformerEncoderLayer(nn.Module):
            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
        )
 
-        self.dropout = torch.nn.Dropout(dropout)
-        self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.05)
+        self.dropout_ff_macaron = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
+        self.dropout_conv = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
+        self.dropout_self_attn = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
+        self.dropout_ff = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
 
    def forward(
@@ -243,7 +245,7 @@ class ConformerEncoderLayer(nn.Module):
            alpha = 1.0
 
        # macaron style feed forward module
-        src = src + self.dropout(self.feed_forward_macaron(src))
+        src = self.dropout_ff_macaron(src, self.feed_forward_macaron(src))
 
        # multi-headed self-attention module
        src_att = self.self_attn(
@@ -254,17 +256,13 @@ class ConformerEncoderLayer(nn.Module):
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]
-        src = src + self.dropout(src_att)
+        src = self.dropout_self_attn(src, src_att)
 
        # convolution module
-        src = src + self.dropout(self.conv_module(src))
+        src = self.dropout_conv(src, self.conv_module(src))
 
        # feed forward module
-        src = src + self.dropout(self.feed_forward(src))
-
-        # encourage dimensions of `src` to be un-correlated with each other, this will
-        # help Adam converge better.
-        src = self.decorrelate(src)
+        src = self.dropout_ff(src, self.feed_forward(src))
 
        src = self.norm_final(self.balancer(src))
 
@@ -1326,6 +1324,9 @@ def _test_random_combine_main():
 
 
 if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
    feature_dim = 50
    c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
    batch_size = 5
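
The net effect of the pruned_transducer_stateless5/conformer.py hunks above is that each residual connection of the form `src = src + self.dropout(branch(src))` becomes a two-argument call such as `src = self.dropout_ff(src, self.feed_forward(src))`: the JoinDropout module itself adds the bypass to the selectively dropped-out branch output. Below is a minimal usage sketch of that call pattern, assuming the JoinDropout class from the scaling.py hunk above is importable; the toy block and its dimensions are illustrative only and are not part of the patch.

# Illustrative only: a toy residual block wired the way ConformerEncoderLayer is
# wired after this patch.  JoinDropout is assumed to be the class added in
# scaling.py above.
import torch
from scaling import JoinDropout  # assumes icefall's scaling.py is on the path

class ToyResidualBlock(torch.nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.linear = torch.nn.Linear(d_model, d_model)
        # one JoinDropout per residual branch, mirroring the four in ConformerEncoderLayer
        self.join_dropout = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # old pattern:  x = x + self.dropout(self.linear(x))
        # new pattern:  the module adds the bypass to the dropped-out branch itself
        return self.join_dropout(x, self.linear(x))

block = ToyResidualBlock(d_model=64)
y = block(torch.randn(10, 20, 64))  # (batch, time, channel); channel_dim=-1 is the default
print(y.shape)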
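
As for the machinery inside _update_transforms: T1 rotates the channels into the eigenbasis of the diagonal-normalized output covariance (where the dimensions are uncorrelated), the per-direction dropout probabilities are made larger for eigenvalues above roughly 1 (the directions that dominate cross-correlations), and T2 maps back, inverting T1 exactly. A standalone sketch of that linear algebra on a fabricated covariance, using the same variable names as the hunk above; the data, sizes, and printouts are illustrative, not icefall code.

# Standalone sketch of the T1/T2 construction from _update_transforms above,
# applied to a fabricated covariance matrix rather than to self.cov.
import torch

num_channels, eps, dropout_rate = 8, 1.0e-04, 0.1

x = torch.randn(1000, num_channels) @ torch.randn(num_channels, num_channels)
cov = (x.t() @ x) / x.shape[0]                 # stand-in for the running stats self.cov

# _normalize_covar(): scale the covariance to have a unit diagonal.
inv_sqrt_diag = (cov.diag() + eps) ** -0.5
norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))

U, S, _ = norm_cov.svd()                       # S.mean() is ~1 because diag(norm_cov) == 1
T1 = inv_sqrt_diag.unsqueeze(-1) * U           # channels -> decorrelated eigenbasis
T2 = U.t() / inv_sqrt_diag                     # eigenbasis -> channels; inverse of T1

# More dropout in directions whose eigenvalue exceeds ~1; the mean is about dropout_rate.
dropout_probs = (S.sqrt() - 0.99).clamp(min=0)
dropout_probs = (dropout_probs * (dropout_rate / dropout_probs.mean())).clamp(max=0.5)

print(torch.allclose(T1 @ T2, torch.eye(num_channels), atol=1e-4))  # True up to round-off
print(dropout_probs)                           # largest where the eigenvalues are largest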