From e535887abb6d8ccff64de8278d51e920306529c1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Jun 2022 23:24:02 +0800 Subject: [PATCH] Bug fixes. --- .../pruned_transducer_stateless2/scaling.py | 29 ++++++++++--------- .../pruned_transducer_stateless5/conformer.py | 2 +- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index f33286335..cc832e1bc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -655,9 +655,9 @@ class ProjDrop(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - - if self.training: - return x * (1.0 - self.dropout_rate) + if not self.training: + # The ** 0.5 is intended to reproduce the scale on (x**2).sum(). + return x * ((1.0 - self.dropout_rate) ** 0.5) else: x = x.transpose(self.channel_dim, -1) # (..., num_channels) num_channels = x.shape[-1] @@ -667,15 +667,14 @@ class ProjDrop(torch.nn.Module): device=x.device, dtype=x.dtype) rr = torch.matmul(r, r.t()) # num_dropped by num_dropped rr += 0.01 # to 100% ensure it is invertible - rr_inv = rr.cholesky.cholesky_inverse() + rr_inv = rr.cholesky().cholesky_inverse() # OK, so r rr_inv r.t() will have eigenvalues of 1. xr = torch.matmul(x, r.t()) # (..., num_dropped) rr_inv_r = torch.matmul(rr_inv, r) # (num_dropped, num_channels) xrr = torch.matmul(xr, rr_inv_r) # (..., num_channels) - return x - xrr - - - + x = x - xrr + x = x.transpose(self.channel_dim, -1) + return x def _test_activation_balancer_sign(): @@ -752,14 +751,18 @@ def _test_double_swish_deriv(): def _test_proj_drop(): - x = torch.randn(3000, 300) + x = torch.randn(30000, 300) m = ProjDrop(0.1) y = m(x) - xmag = (x*x).sqrt().mean() - ymag = (y*y).sqrt().mean() + xmag = (x*x).mean() + ymag = (y*y).mean() print(f"xmag = {xmag}, ymag = {ymag}") - assert abs((ymag / xmag) - 0.9) < 0.01 - + assert abs((ymag / xmag) - 0.9) < 0.02 + m.eval() + y = m(x) + ymag = (y*y).mean() + print(f"xmag[eval] = {xmag}, ymag = {ymag}") + assert abs((ymag / xmag) - 0.9) < 0.02 if __name__ == "__main__": _test_proj_drop() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index ebec92bf6..9f46e14f9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -369,7 +369,7 @@ class RelPositionalEncoding(torch.nn.Module): """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model - self.dropout = ProjDrop(p=dropout_rate) + self.dropout = ProjDrop(dropout_rate) self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len))