From 31848dcd11dcceaa5ed8833991c7b277affefb9f Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 6 Jun 2022 16:05:18 +0800
Subject: [PATCH] Randomize the projections

---
 .../pruned_transducer_stateless2/scaling.py   | 19 +++++++++++++++----
 .../pruned_transducer_stateless5/conformer.py |  6 +++---
 2 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 4ab34743f..25e3fdccb 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -19,6 +19,7 @@
 import collections
 from itertools import repeat
 from typing import Optional, Tuple
+import random
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -639,7 +640,7 @@ class ScaledEmbedding(nn.Module):
         return s.format(**self.__dict__)
 
 
-class FixedProjDrop(torch.nn.Module):
+class RandProjDrop(torch.nn.Module):
     """
     This has an effect similar to torch.nn.Dropout, but does not privilege the on-axis directions.
     The directions of dropout are fixed when the class is initialized, and are orthogonal.
@@ -651,7 +652,7 @@
                  num_channels: int,
                  dropout_rate: float = 0.1,
                  channel_dim: int = -1):
-        super(FixedProjDrop, self).__init__()
+        super(RandProjDrop, self).__init__()
         self.dropout_rate = dropout_rate
         self.channel_dim = channel_dim
 
@@ -665,6 +666,7 @@ class FixedProjDrop(torch.nn.Module):
             x = torch.nn.functional.dropout(x, self.dropout_rate,
                                             training=False)
         else:
+            self._randomize_U()
             x = x.transpose(self.channel_dim, -1) # (..., num_channels)
             x = torch.matmul(x, self.U)
@@ -674,7 +676,16 @@
             x = x.transpose(self.channel_dim, -1) # (..., num_channels)
         return x
-
+    def _randomize_U(self):
+        dim = random.randint(0, 1)
+        U = self.U
+        num_channels = U.shape[0]
+        # pick place to split U in two pieces.
+        r = random.randint(1, num_channels - 2)
+        U_part1 = U.narrow(dim, 0, r)
+        U_part2 = U.narrow(dim, r, num_channels-r)
+        U = torch.cat((U_part2, U_part1), dim=dim)
+        self.U[:] = U
 
 
 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
@@ -751,7 +762,7 @@ def _test_double_swish_deriv():
 
 def _test_proj_drop():
     x = torch.randn(30000, 300)
-    m = FixedProjDrop(300, 0.1)
+    m = RandProjDrop(300, 0.1)
     y = m(x)
     xmag = (x*x).mean()
     ymag = (y*y).mean()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
index 47940cc81..77f00cc63 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -29,7 +29,7 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
-    FixedProjDrop,
+    RandProjDrop,
 )
 from torch import Tensor, nn
 
@@ -197,7 +197,7 @@ class ConformerEncoderLayer(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
         )
 
-        self.dropout = FixedProjDrop(d_model, dropout)
+        self.dropout = RandProjDrop(d_model, dropout)
 
     def forward(
         self,
@@ -369,7 +369,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
-        self.dropout = FixedProjDrop(d_model, dropout_rate)
+        self.dropout = RandProjDrop(d_model, dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
 
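
The narrow/cat trick in _randomize_U above amounts to a cyclic permutation of the rows or columns of U, so U stays orthogonal and RandProjDrop keeps preserving signal energy on average while the dropped directions change from call to call. Below is a small standalone sketch (not part of the patch) that checks this; the values num_channels = 8, dim = 0 and r = 3 are arbitrary choices for illustration, and torch.linalg.qr is used here only to build a random orthogonal matrix for the demo.

    import torch

    # Standalone illustration (not from the patch): the narrow/cat in
    # _randomize_U is a cyclic permutation of U along one dimension.
    torch.manual_seed(0)
    num_channels = 8  # arbitrary demo size
    U, _ = torch.linalg.qr(torch.randn(num_channels, num_channels))  # random orthogonal matrix

    dim = 0  # 0 permutes rows, 1 permutes columns
    r = 3    # split point, in [1, num_channels - 2] as in _randomize_U
    U_part1 = U.narrow(dim, 0, r)
    U_part2 = U.narrow(dim, r, num_channels - r)
    U_new = torch.cat((U_part2, U_part1), dim=dim)

    # The same operation written as a roll along `dim`.
    assert torch.allclose(U_new, torch.roll(U, shifts=-r, dims=dim))

    # Permuting rows (or columns) of an orthogonal matrix keeps it orthogonal,
    # so projecting with U_new and back with U_new.t() is still norm-preserving.
    assert torch.allclose(U_new @ U_new.t(), torch.eye(num_channels), atol=1e-5)

Because the update only permutes U, the buffer created in __init__ can be refreshed in place, which is what the patch does with self.U[:] = U.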