From a6050cb2de0f66e17104e350ae59091bb89bb606 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 7 Jun 2022 23:38:38 +0800
Subject: [PATCH] Implement new, more principled but maybe slower version.

---
 .../pruned_transducer_stateless2/scaling.py   | 152 ++++++++++++------
 .../pruned_transducer_stateless5/conformer.py |   2 +-
 2 files changed, 103 insertions(+), 51 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index ff5066a0a..8ff009a3f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -736,52 +736,111 @@ class Decorrelate(torch.nn.Module):
-          max_dropout_rate: This is an upper limit, for safety, on how aggressive
-                 the randomization can be.
           eps: An epsilon used to prevent division by zero.
+          beta: A value 0 < beta < 1 that controls the decay of the covariance stats.
     """
     def __init__(self,
                  num_channels: int,
                  apply_prob: float = 0.25,
-                 dropout_rate: float = 0.01,
-                 max_dropout_rate: float = 0.1,
+                 dropout_rate: float = 0.1,
                  eps: float = 1.0e-04,
+                 beta: float = 0.95,
                  channel_dim: int = -1):
         super(Decorrelate, self).__init__()
         self.apply_prob = apply_prob
         self.dropout_rate = dropout_rate
-        self.max_dropout_rate = max_dropout_rate
         self.channel_dim = channel_dim
         self.eps = eps
+        self.beta = beta

         rand_mat = torch.randn(num_channels, num_channels)
         U, _, _ = rand_mat.svd()
+        self.register_buffer('U', U)  # a random orthogonal square matrix; will be a buffer.
+
+        self.register_buffer('T1', torch.eye(num_channels))
+        self.register_buffer('rand_scales', torch.zeros(num_channels))
+        self.register_buffer('nonrand_scales', torch.ones(num_channels))
+        self.register_buffer('T2', torch.eye(num_channels))
+        self.register_buffer('cov', torch.zeros(num_channels, num_channels))
+        self.step = 0

-    def _get_covar(self, x: Tensor) -> Tensor:
+    def _update_covar_stats(self, x: Tensor) -> None:
         """
-        Returns the uncentered covariance matrix associated with feature matrix x, detached
-        from its input.
         Args:
           x: Tensor of shape (*, num_channels)
-        Returns:
-          Covariance matrix `cov`, of shape (num_channels, num_channels)
+        Updates the covariance stats self.cov (a moving average with decay self.beta).
         """
         x = x.detach()
         x = x.reshape(-1, x.shape[-1])
         x = x * (x.shape[0] ** -0.5)  # avoid overflow in half precision
-        return torch.matmul(x.t(), x)
+        cov = torch.matmul(x.t(), x)
+        self.cov.mul_(self.beta).add_(cov, alpha=(1 - self.beta))
+        self.step += 1

-    def _normalize_covar(self, cov: Tensor, eps: float) -> Tensor:
+    def _update_transforms(self):
+        norm_cov, inv_sqrt_diag = self._normalize_covar(self.cov)
+
+        U, S, _ = norm_cov.svd()
+
+        if random.random() < 0.1:
+            print(f"Decorrelate: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
+
+        # row indexes of U correspond to channels, column indexes correspond to
+        # singular values: norm_cov = U * diag(S) * U.t(), where * is matmul.
+        S_sqrt = (S + self.eps) ** 0.5
+
+        # Transform T1, which we'll incorporate as torch.matmul(x, self.T1), is:
+        # (i) multiply by inv_sqrt_diag, which makes the covariance have a unit
+        #     diagonal,
+        # (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels),
+        # (iii) divide by S_sqrt, which makes all dims have unit variance.
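+        # Equivalently, writing D = diag(inv_sqrt_diag) and norm_cov = U * diag(S) * U.t(),
+        # T1 equals D * U * diag(S)**-0.5, so cov(torch.matmul(x, T1)) is approximately
+        # the identity (up to self.eps): the transformed features are white.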
+        self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U / S_sqrt)
+
+        # Transform T2, which we'll incorporate as torch.matmul(x, self.T2), is the
+        # inverse of T1:
+        # (i) multiply by S_sqrt, which restores the variances of the different dims,
+        # (ii) multiply by U.t(), which undoes the diagonalization (re-correlates the
+        #      channels),
+        # (iii) divide by inv_sqrt_diag, which restores the covariance's original
+        #      diagonal values.
+        self.T2[:] = (S_sqrt.unsqueeze(-1) * U.t() / inv_sqrt_diag)
+
+        # OK, now get rand_proportion, whose elements lie between 0 and self.dropout_rate;
+        # it determines how much randomness goes into each eigen-direction of norm_cov.
+        # Basically, we want more randomness in directions with eigenvalues more than one,
+        # and none in those with eigenvalues less than one.
+        rand_proportion = (S - 1.0).clamp(min=0.0, max=1.0) * self.dropout_rate
+
+        # rand_proportion is viewed as representing a proportion of the covariance, since
+        # the random and nonrandom components will not be correlated.
+        self.rand_scales = rand_proportion.sqrt()
+        self.nonrand_scales = (1.0 - rand_proportion).sqrt()
+
+        if True:
+            # Sanity check: T1 and T2 should be (approximate) inverses.
+            d = torch.matmul(self.T1, self.T2) - torch.eye(self.T1.shape[0],
+                                                           device=self.T1.device,
+                                                           dtype=self.T1.dtype)
+            assert torch.all(d.abs() < 0.01)

+    def _normalize_covar(self, cov: Tensor) -> Tuple[Tensor, Tensor]:
         """
         Normalizes a covariance matrix so that its diagonal is 1, by multiplying by
         its diagonal**-0.5 on both sides.
         Args:
            cov: matrix to normalize
-           eps: floating point value >0, used to prevent division by zero.
         Returns normalized_cov, inv_sqrt_diag
         """
         diag = cov.diag()
-        inv_sqrt_diag = (diag + eps) ** -0.5
+        inv_sqrt_diag = (diag + self.eps) ** -0.5
         cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
+        assert torch.all((cov.diag() - 1.0).abs() < 0.1)  # TODO: remove
         return cov, inv_sqrt_diag

@@ -818,46 +877,33 @@ class Decorrelate(torch.nn.Module):
             return x
         else:
             x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
-            x_bypass = x  # will be used for "+ I"
             with torch.cuda.amp.autocast(enabled=False):
-                cov = self._get_covar(x)
-                cov, inv_sqrt_diag = self._normalize_covar(cov, self.eps)
-                avg_squared_eig = (cov**2).sum(dim=0).mean()
-                if random.random() < 0.001 or __name__ == "__main__":
-                    logging.info(f"Decorrelate: avg_squared_eig = {avg_squared_eig}")
+                self._update_covar_stats(x)
+                if self.step % 50 == 0 or __name__ == "__main__":
+                    self._update_transforms()

-                # the odd-looking formula below was obtained empirically, to match
-                # the self-product and cross-correlation statistics of dropout
+                x = torch.matmul(x, self.T1)

-                x = x * inv_sqrt_diag
+                x_bypass = x

-                rand_scale1 = ((self.max_dropout_rate / (1.0 - self.max_dropout_rate)) ** 0.5) / avg_squared_eig
-                rand_scale2 = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5)
-                rand_scale = torch.minimum(rand_scale1, torch.tensor(rand_scale2, device=x.device))
+                if True:
+                    # This block, in effect, multiplies x by a random orthogonal matrix,
+                    # giving us random noise.
+                    perm = self._randperm_like(x)
+                    x = torch.gather(x, -1, perm)
+                    # self.U will act like a different matrix for every row of x,
+                    # because of the random permutation.
+                    x = torch.matmul(x, self.U)
+                    x_next = torch.empty_like(x)
+                    # scatter_ uses perm in the opposite way from gather,
+                    # inverting it.
+                    x_next.scatter_(-1, perm, x)
+                    x = x_next
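+                    # Note: gather(.., perm) and scatter_(.., perm) apply mutually
+                    # inverse permutations, so each row of x is mapped by a matrix
+                    # of the form P * U * P^-1 (P a permutation matrix); that is a
+                    # product of orthogonal matrices, so the noise has the same
+                    # per-row norm as x.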
-                # by multiplying by `cov`, then randomizing the sign of elements, then
-                # multiplying by `cov` again, we are generating something that has
-                # more noise in directions corresponding to larger eigenvalues of `cov`.
-                # (Actually we scale by the square of the eigenvalue, which is not very
-                # desirable, but was easy to implement in a fast way.)
-                x = torch.matmul(x * rand_scale, cov)
+                x = (x * self.rand_scales) + (x_bypass * self.nonrand_scales)

-                perm = self._randperm_like(x)
-                x = torch.gather(x, -1, perm)
-                # self.U will act like a different matrix for every row of x,
-                # because of the random permutation.
-                x = torch.matmul(x, self.U)
-                x_next = torch.empty_like(x)
-                # scatter_ uses perm in opposite way
-                # from gather, inverting it.
-                x_next.scatter_(-1, perm, x)
-                x = x_next
-
-                x = torch.matmul(x, cov)
-                x = x / inv_sqrt_diag
-
-                x = x + x_bypass
+                x = torch.matmul(x, self.T2)

             x = x.transpose(self.channel_dim, -1)
             return x

@@ -938,12 +984,13 @@ def _test_double_swish_deriv():

 def _test_gauss_proj_drop():
-    x = torch.randn(30000, 384)
+    D = 384
+    x = torch.randn(30000, D)

     for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
         m1 = torch.nn.Dropout(dropout_rate)
-        m2 = GaussProjDrop(384, dropout_rate)
+        m2 = GaussProjDrop(D, dropout_rate)
         for mode in ['train', 'eval']:
             y1 = m1(x)
             y2 = m2(x)

@@ -958,12 +1005,17 @@ def _test_gauss_proj_drop():

 def _test_decorrelate():
     logging.getLogger().setLevel(logging.INFO)
-    x = torch.randn(30000, 384)
+    D = 384
+    x = torch.randn(30000, D)
+
+    # give it a non-unit covariance.
+    m = torch.randn(D, D)
+    x = torch.matmul(x, m)

     for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
         m1 = torch.nn.Dropout(dropout_rate)
-        m2 = Decorrelate(384, apply_prob=1.0, dropout_rate=dropout_rate, max_dropout_rate=dropout_rate)
+        m2 = Decorrelate(D, apply_prob=1.0, dropout_rate=dropout_rate)
         for mode in ['train', 'eval']:
             y1 = m1(x)
             y2 = m2(x)

@@ -972,7 +1024,7 @@ def _test_decorrelate():
             cross1 = (x*y1).mean()
             y2mag = (y2*y2).mean()
             cross2 = (x*y2).mean()
-            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}")
+            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}, ratio1={y1mag/cross1}, ratio2={y2mag/cross2}")

             m1.eval()
             m2.eval()

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
index 00ff3f3cc..f7fd6ce61 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -199,7 +199,7 @@ class ConformerEncoderLayer(nn.Module):
         )

         self.dropout = torch.nn.Dropout(dropout)
-        self.decorrelate = Decorrelate(d_model, apply_prob=0.25)
+        self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.2)

     def forward(
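
The code below is a minimal, self-contained sketch (not part of the patch) of the
whitening round-trip that _update_transforms() constructs above: it builds T1 and T2
with the same formulas and checks that torch.matmul(x, T1) has approximately unit
covariance and that T1 and T2 are inverses.  The dimensions and the eps value are
illustrative only.

import torch

num_channels, eps = 8, 1.0e-04
# features with a non-unit covariance, as in _test_decorrelate()
x = torch.randn(10000, num_channels) @ torch.randn(num_channels, num_channels)

cov = (x.t() @ x) / x.shape[0]                # uncentered covariance
inv_sqrt_diag = (cov.diag() + eps) ** -0.5
norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))  # unit diagonal

U, S, _ = norm_cov.svd()                      # norm_cov = U diag(S) U.t()
S_sqrt = (S + eps) ** 0.5

T1 = inv_sqrt_diag.unsqueeze(-1) * U / S_sqrt       # whitening transform
T2 = S_sqrt.unsqueeze(-1) * U.t() / inv_sqrt_diag   # its inverse

y = x @ T1
print((y.t() @ y / y.shape[0]).diag())                  # each entry close to 1.0
print((T1 @ T2 - torch.eye(num_channels)).abs().max())  # close to 0.0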