diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 5ee4bab98..f33286335 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -639,6 +639,50 @@ class ScaledEmbedding(nn.Module):
         return s.format(**self.__dict__)
 
 
+class ProjDrop(torch.nn.Module):
+    """
+    This has an effect similar to torch.nn.Dropout, but instead of zeroing
+    individual elements it zeroes out randomly chosen directions, so it does
+    not privilege the on-axis directions.
+
+    dropout_rate: the dropout probability (actually defines the number of
+        zeroed-out directions, int(dropout_rate * num_channels)).
+    channel_dim: the axis corresponding to the channel, e.g. -1, 0, 1, 2.
+    """
+
+    def __init__(self,
+                 dropout_rate: float = 0.1,
+                 channel_dim: int = -1):
+        super(ProjDrop, self).__init__()
+        self.dropout_rate = dropout_rate
+        self.channel_dim = channel_dim
+
+    def forward(self, x: Tensor) -> Tensor:
+        if not self.training:
+            # In eval mode just scale down by (1 - dropout_rate); unlike
+            # torch.nn.Dropout we do not rescale during training.
+            return x * (1.0 - self.dropout_rate)
+        else:
+            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
+            num_channels = x.shape[-1]
+            num_dropped = int(self.dropout_rate * num_channels)
+
+            r = torch.randn(num_dropped, num_channels,
+                            device=x.device, dtype=x.dtype)
+            rr = torch.matmul(r, r.t())  # (num_dropped, num_dropped)
+            # Add a small value to the diagonal to ensure rr is invertible.
+            rr += 0.01 * torch.eye(num_dropped,
+                                   device=x.device, dtype=x.dtype)
+            rr_inv = rr.cholesky().cholesky_inverse()
+            # r.t() rr_inv r is (very close to) a projection onto the row
+            # space of r: it has num_dropped eigenvalues of ~1, the rest 0.
+            xr = torch.matmul(x, r.t())  # (..., num_dropped)
+            rr_inv_r = torch.matmul(rr_inv, r)  # (num_dropped, num_channels)
+            xrr = torch.matmul(xr, rr_inv_r)  # (..., num_channels)
+            # Subtract the dropped component and restore the dim order.
+            return (x - xrr).transpose(self.channel_dim, -1)
+
+
 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
     N = 1000
@@ -712,7 +756,19 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x)
 
 
+def _test_proj_drop():
+    x = torch.randn(3000, 300)
+    m = ProjDrop(0.1)
+    y = m(x)  # training mode by default, so random directions are dropped
+    xmag = (x * x).mean().sqrt()
+    ymag = (y * y).mean().sqrt()
+    print(f"xmag = {xmag}, ymag = {ymag}")
+    # Dropping 10% of the directions scales the RMS value by about sqrt(0.9).
+    assert abs((ymag / xmag) - 0.9 ** 0.5) < 0.02
+
+
 if __name__ == "__main__":
+    _test_proj_drop()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
     _test_basic_norm()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
index 6f7231f4b..ebec92bf6 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -29,6 +29,7 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
+    ProjDrop,
 )
 from torch import Tensor, nn
 
@@ -196,7 +197,7 @@ class ConformerEncoderLayer(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
         )
 
-        self.dropout = nn.Dropout(dropout)
+        self.dropout = ProjDrop(dropout)
 
     def forward(
         self,
@@ -368,7 +369,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
-        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.dropout = ProjDrop(dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
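
A quick standalone check of the projection math used in ProjDrop.forward (this sketch is not part of the patch; the sizes and the helper name P are only for illustration and mirror the defaults of 300 channels with dropout_rate = 0.1): r.t() rr_inv r should behave as a projection onto the row space of r, so x - xrr keeps the other num_channels - num_dropped directions unchanged and has essentially no component along the dropped ones.

    import torch

    num_channels, num_dropped = 300, 30  # mirrors dropout_rate = 0.1 with 300 channels

    r = torch.randn(num_dropped, num_channels)
    rr = torch.matmul(r, r.t()) + 0.01 * torch.eye(num_dropped)
    rr_inv = rr.cholesky().cholesky_inverse()

    # P projects onto the row space of r; ProjDrop applies (I - P) to x.
    P = torch.matmul(r.t(), torch.matmul(rr_inv, r))  # (num_channels, num_channels)

    eigs = torch.linalg.eigvalsh(P)  # ascending order
    print(eigs[-num_dropped:].min())  # ~1.0: the dropped directions
    print(eigs[:-num_dropped].max())  # ~0.0: the preserved directions

    x = torch.randn(8, num_channels)
    y = x - torch.matmul(x, P)
    print(torch.matmul(y, r.t()).abs().max())  # close to 0: nothing left along the rows of r

Because of the 0.01 diagonal term the large eigenvalues are slightly below 1 (about lambda / (lambda + 0.01), with lambda on the order of num_channels), which is negligible at these sizes.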