From cd6b707e2b334f800766da4b3c3077a9e31ccc2d Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 7 Jun 2022 16:45:32 +0800
Subject: [PATCH] Various bug fixes

---
 .../pruned_transducer_stateless2/scaling.py   | 116 +++++++++++++++++-
 .../pruned_transducer_stateless5/conformer.py |  13 +-
 2 files changed, 125 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index a33b3ea36..75a587370 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -713,6 +713,100 @@ class GaussProjDrop(torch.nn.Module):
             return x
 
 
+class Decorrelate(torch.nn.Module):
+    """
+    This module is somewhat similar to dropout: it is a random transformation
+    that does nothing in eval mode.
+    It is designed specifically to encourage the input data to be decorrelated, i.e.
+    to have a diagonal covariance matrix (not necessarily the identity).
+
+    To save time, in training mode we only apply it on randomly selected minibatches.
+
+    Args:
+      channel_dim: The dimension of the input to treat as the channel dimension, e.g. -1.
+      apply_prob: The probability with which we apply this each time, in
+               training mode.  This is to save time (but of course it
+               will tend to make the effect weaker).
+      dropout_rate: This number determines the scale of the random multiplicative
+               noise, in such a way that the self-correlation and cross-correlation
+               statistics match those of dropout with the same `dropout_rate`
+               (assuming we applied the transform, e.g. if apply_prob == 1.0).
+      eps:   An epsilon used to prevent division by zero.
+    """
+    def __init__(self,
+                 apply_prob: float = 0.25,
+                 dropout_rate: float = 0.1,
+                 eps: float = 1.0e-04,
+                 channel_dim: int = -1):
+        super(Decorrelate, self).__init__()
+        self.apply_prob = apply_prob
+        self.dropout_rate = dropout_rate
+        self.channel_dim = channel_dim
+        self.eps = eps
+
+    def _get_covar(self, x: Tensor) -> Tensor:
+        """
+        Returns the uncentered covariance matrix associated with feature matrix x,
+        detached from its input.
+        Args:
+            x: Tensor of shape (*, num_channels)
+        Returns:
+            Covariance matrix `cov`, of shape (num_channels, num_channels)
+        """
+        x = x.detach()
+        x = x.reshape(-1, x.shape[-1])
+        x = x * (x.shape[0] ** -0.5)  # avoid overflow in half precision
+        return torch.matmul(x.t(), x)
+
+    def _normalize_covar(self, cov: Tensor, eps: float) -> Tensor:
+        """
+        Normalizes a covariance matrix so that its diagonal is 1, by multiplying by
+        its diagonal**-0.5 on both sides.
+        Args:
+            cov: matrix to normalize
+            eps: floating point value >0, used to prevent division by zero.
+        """
+        diag = cov.diag()
+        inv_sqrt_diag = (diag + eps) ** -0.5
+        cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
+        return cov
+
+
+    def forward(self, x: Tensor) -> Tensor:
+        if not self.training or random.random() > self.apply_prob:
+            return x
+        else:
+            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
+            x_bypass = x  # will be used for "+ I"
+
+            with torch.cuda.amp.autocast(enabled=False):
+                cov = self._get_covar(x)
+                cov = self._normalize_covar(cov, self.eps)
+                avg_squared_eig = (cov**2).sum(dim=0).mean()
+
+            # the odd-looking formula below was obtained empirically, to match
+            # the self-product and cross-correlation statistics of dropout.
+            rand_scale = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5) / avg_squared_eig
+
+            # By multiplying by `cov`, then randomizing the sign of elements, then
+            # multiplying by `cov` again, we are generating something that has
+            # more noise in directions corresponding to larger eigenvalues of `cov`.
+            # (Actually we scale by the square of the eigenvalue, which is not very
+            # desirable, but was easy to implement in a fast way.)
+            x = torch.matmul(x * rand_scale, cov)
+            rand_mask = (torch.rand_like(x) > 0.5)
+            # Randomize the sign of elements of x.  It is important to write the
+            # expression this way, so that only rand_mask needs to be stored for
+            # backprop.
+            x = x - 2 * (rand_mask * x)
+            x = torch.matmul(x, cov)
+            x = x + x_bypass
+            x = x.transpose(self.channel_dim, -1)
+            return x
+
+
+
+
 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
     N = 1000
@@ -805,10 +899,30 @@ def _test_gauss_proj_drop():
         m1.eval()
         m2.eval()
 
 
+def _test_decorrelate():
+    x = torch.randn(30000, 384)
+
+
+    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
+        m1 = torch.nn.Dropout(dropout_rate)
+        m2 = Decorrelate(apply_prob=1.0, dropout_rate=dropout_rate)
+        for mode in ['train', 'eval']:
+            y1 = m1(x)
+            y2 = m2(x)
+            xmag = (x*x).mean()
+            y1mag = (y1*y1).mean()
+            cross1 = (x*y1).mean()
+            y2mag = (y2*y2).mean()
+            cross2 = (x*y2).mean()
+            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}")
+            m1.eval()
+            m2.eval()
+
 if __name__ == "__main__":
-    _test_gauss_proj_drop()
+    _test_decorrelate()
     if False:
+        _test_gauss_proj_drop()
         _test_activation_balancer_sign()
         _test_activation_balancer_magnitude()
         _test_basic_norm()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
index e7a46f7be..fd2278781 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -29,7 +29,8 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
-    GaussProjDrop,
+    Decorrelate,
+
 )
 from torch import Tensor, nn
 
@@ -197,7 +198,9 @@ class ConformerEncoderLayer(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
         )
 
-        self.dropout = GaussProjDrop(d_model, dropout)
+        self.dropout = torch.nn.Dropout(dropout)
+        self.decorrelate = Decorrelate(apply_prob=0.25, dropout_rate=0.05)
+
 
     def forward(
         self,
@@ -259,6 +262,10 @@ class ConformerEncoderLayer(nn.Module):
         # feed forward module
         src = src + self.dropout(self.feed_forward(src))
 
+        # Encourage dimensions of `src` to be uncorrelated with each other;
+        # this will help Adam converge better.
+        src = self.decorrelate(src)
+
         src = self.norm_final(self.balancer(src))
 
         if alpha != 1.0:
@@ -369,7 +376,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
-        self.dropout = GaussProjDrop(d_model, dropout_rate)
+        self.dropout = torch.nn.Dropout(dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
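
Note (illustration only, not part of the patch): torch.nn.Dropout is "inverted" dropout,
y = x * mask / (1 - p), so E[x*y] = E[x^2] while E[y^2] = E[x^2] / (1 - p); the excess
noise power relative to the signal is p / (1 - p).  These are the statistics that the
Decorrelate docstring says it matches, and presumably where the
(dropout_rate / (1.0 - dropout_rate)) ** 0.5 factor in rand_scale comes from.  A minimal
sketch of the dropout side of that comparison:

    import torch

    p = 0.1
    x = torch.randn(100000, 64)
    y = torch.nn.Dropout(p)(x)        # a freshly constructed module is in training mode

    print((x * x).mean().item())      # ~1.0  : E[x^2]
    print((x * y).mean().item())      # ~1.0  : E[x*y] == E[x^2]
    print((y * y).mean().item())      # ~1.11 : E[x^2] / (1 - p)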
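
Note (illustration only, not part of the patch): the sign-randomization step
x = x - 2 * (rand_mask * x) in Decorrelate.forward() flips the sign of x wherever
rand_mask is True; as the in-code comment says, writing it this way means only the
boolean rand_mask has to be saved for backprop.  A small equivalence check:

    import torch

    x = torch.randn(4, 5, requires_grad=True)
    rand_mask = torch.rand_like(x) > 0.5
    y1 = x - 2 * (rand_mask * x)          # form used in the patch
    y2 = torch.where(rand_mask, -x, x)    # more obvious equivalent
    print(torch.allclose(y1, y2))         # True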