diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 022be8053..cf2f05999 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -27,10 +27,7 @@ from scaling import (
     BasicNorm,
     DoubleSwish,
     ScaledConv1d,
-    ScaledConv2d,
-    ScaledLinear,
-    StructuredConv1d,
-    StructuredLinear,
+    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
 )
 from torch import Tensor, nn
 
@@ -1023,9 +1020,7 @@ class Conv2dSubsampling(nn.Module):
             DoubleSwish(),
         )
         out_height = (((in_channels - 1) // 2 - 1) // 2)
-        self.out = StructuredLinear(
-            (out_height, layer3_channels), (out_channels,)
-        )
+        self.out = nn.Linear(out_height * layer3_channels, out_channels)
         # set learn_eps=False because out_norm is preceded by `out`, and `out`
         # itself has learned scale, so the extra degree of freedom is not
         # needed.
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 8ae390f45..9b2c7a19d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -314,11 +314,12 @@ class StructuredConv1d(nn.Conv1d):
 
 
-
-class ScaledLinear(nn.Linear):
+def ScaledLinear(*args,
+                 initial_scale: float = 1.0,
+                 **kwargs ) -> nn.Linear:
     """
-    A modified version of nn.Linear that gives an easy way to set the
-    default initial parameter scale.
+    Behaves like a constructor of a modified version of nn.Linear
+    that gives an easy way to set the default initial parameter scale.
 
     Args:
         Accepts the standard args and kwargs that nn.Linear accepts
         e.g. in_features, out_features, bias=False.
@@ -330,67 +331,42 @@ class ScaledLinear(nn.Linear):
            Another option, if you want to do something like this, is
           to re-initialize the parameters.
     """
-
-    def __init__(
-        self,
-        *args,
-        initial_scale: float = 1.0,
-        **kwargs
-    ):
-        super(ScaledLinear, self).__init__(*args, **kwargs)
-        with torch.no_grad():
-            self.weight[:] *= initial_scale
-            if self.bias is not None:
-                torch.nn.init.uniform_(self.bias,
-                                       -0.1 * initial_scale,
-                                       0.1 * initial_scale)
+    ans = nn.Linear(*args, **kwargs)
+    with torch.no_grad():
+        ans.weight[:] *= initial_scale
+        if ans.bias is not None:
+            torch.nn.init.uniform_(ans.bias,
+                                   -0.1 * initial_scale,
+                                   0.1 * initial_scale)
+    return ans
 
 
-class ScaledConv1d(nn.Conv1d):
-    # See docs for ScaledLinear
-    def __init__(
-        self,
-        *args,
-        initial_scale: float = 1.0,
-        **kwargs
-    ):
-        super(ScaledConv1d, self).__init__(*args, **kwargs)
-        with torch.no_grad():
-            self.weight[:] *= initial_scale
-            if self.bias is not None:
-                torch.nn.init.uniform_(self.bias,
-                                       -0.1 * initial_scale,
-                                       0.1 * initial_scale)
+def ScaledConv1d(*args,
+                 initial_scale: float = 1.0,
+                 **kwargs ) -> nn.Conv1d:
+    """
+    Behaves like a constructor of a modified version of nn.Conv1d
+    that gives an easy way to set the default initial parameter scale.
 
-    def get_weight(self):  # TODO: delete
-        return self.weight
+    Args:
+        Accepts the standard args and kwargs that nn.Conv1d accepts
+        e.g. in_channels, out_channels, kernel_size, bias=False.
-    def get_bias(self):  # TODO: delete
-        return self.bias
-
-
-class ScaledConv2d(nn.Conv2d):
-    # See docs for ScaledLinear
-    def __init__(
-        self,
-        *args,
-        initial_scale: float = 1.0,
-        **kwargs
-    ):
-        super(ScaledConv2d, self).__init__(*args, **kwargs)
-        with torch.no_grad():
-            self.weight[:] *= initial_scale
-            if self.bias is not None:
-                torch.nn.init.uniform_(self.bias,
-                                       -0.1 * initial_scale,
-                                       0.1 * initial_scale)
-
-    def get_weight(self):
-        return self.weight
-
-    def get_bias(self):
-        return self.bias
 
+        initial_scale: you can override this if you want to increase
+           or decrease the initial magnitude of the module's output
+           (affects the initialization of weight_scale and bias_scale).
+           Another option, if you want to do something like this, is
+           to re-initialize the parameters.
+    """
+    ans = nn.Conv1d(*args, **kwargs)
+    with torch.no_grad():
+        ans.weight[:] *= initial_scale
+        if ans.bias is not None:
+            torch.nn.init.uniform_(ans.bias,
+                                   -0.1 * initial_scale,
+                                   0.1 * initial_scale)
+    return ans
 
 
@@ -497,80 +473,6 @@ class DoubleSwish(torch.nn.Module):
 
 
 
-class GaussProjDrop(torch.nn.Module):
-    """
-    This has an effect similar to torch.nn.Dropout, but does not privilege the on-axis directions.
-    The directions of dropout are fixed when the class is initialized, and are orthogonal.
-
-    dropout_rate: the dropout probability (actually will define the number of zeroed-out directions)
-    channel_dim: the axis corresponding to the channel, e.g. -1, 0, 1, 2.
-    """
-    def __init__(self,
-                 num_channels: int,
-                 dropout_rate: float = 0.1,
-                 channel_dim: int = -1):
-        super(GaussProjDrop, self).__init__()
-        self.dropout_rate = dropout_rate
-        # this formula for rand_scale was found empirically, trying to match the
-        # statistics of dropout in terms of cross-correlation with the input, see
-        # _test_gauss_proj_drop()
-        self.rand_scale = (dropout_rate / (1-dropout_rate)) ** 0.5 # * (num_channels ** -0.5)
-
-        self.channel_dim = channel_dim
-
-        rand_mat = torch.randn(num_channels, num_channels)
-        U, _, _ = rand_mat.svd()
-        self.register_buffer('U', U) # a random orthogonal square matrix. will be a buffer.
-
-
-    def _randperm_like(self, x: Tensor):
-        """
-        Returns random permutations of the integers [0,1,..x.shape[-1]-1],
-        with the same shape as x.  All dimensions of x other than the last dimension
-        will be treated as batch dimensions.
-
-        Torch's randperm does not support a batch dimension, so we pseudo-randomly simulate it.
-
-        For now, requires x.shape[-1] to be either a power of 2 or 3 times a power of 2, as
-        we normally set channel dims.  This is required for some number theoretic stuff.
-        """
-        n = x.shape[-1]
-
-        assert n & (n-1) == 0 or (n//3 & (n//3 - 1)) == 0
-
-        b = x.numel() // n
-        randint = random.randint(0, 1000)
-        perm = torch.randperm(n, device=x.device)
-        # ensure all elements of batch_rand are coprime to n; this will ensure
-        # that multiplying the permutation by batch_rand and taking modulo
-        # n leaves us with permutations.
-        batch_rand = torch.arange(b, device=x.device) * (randint * 6) + 1
-        batch_rand = batch_rand.unsqueeze(-1)
-        ans = (perm * batch_rand) % n
-        ans = ans.reshape(x.shape)
-        return ans
-
-
-    def forward(self, x: Tensor) -> Tensor:
-        if not self.training:
-            return x
-        else:
-            x = x.transpose(self.channel_dim, -1) # (..., num_channels)
-            x_bypass = x # will be used for "+ I"
-            perm = self._randperm_like(x)
-            x = torch.gather(x, -1, perm)
-            # self.U will act like a different matrix for every row of x, because of the random
-            # permutation.
-            x = torch.matmul(x, self.U)
-            x_next = torch.empty_like(x)
-            # scatter_ uses perm in opposite way
-            # from gather, inverting it.
-            x_next.scatter_(-1, perm, x)
-            x = (x_next * self.rand_scale + x_bypass)
-            return x
-
-
-
 
 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
@@ -644,52 +546,11 @@ def _test_double_swish_deriv():
     m = DoubleSwish()
     torch.autograd.gradcheck(m, x)
 
-def _test_structured_linear():
-    m = StructuredLinear((2, 100), (3, 100), bias=True)
-    assert m.weight.shape == (3, 100, 2, 100)
-    assert m.bias.shape == (3, 100)
-    x = torch.randn(50, 200)
-    y = m(x)
-    assert y.shape == (50, 300)
-
-def _test_structured_conv1d():
-    m = StructuredConv1d((2, 100), (3, 100), kernel_size=3, padding=1, bias=True)
-    assert m.weight.shape == (3, 100, 2, 100, 3)
-    assert m.bias.shape == (3, 100)
-    T = 39
-    x = torch.randn(50, 200, T)
-    y = m(x)
-    assert y.shape == (50, 300, T)
-
-def _test_gauss_proj_drop():
-    D = 384
-    x = torch.randn(30000, D)
-
-
-    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
-        m1 = torch.nn.Dropout(dropout_rate)
-        m2 = GaussProjDrop(D, dropout_rate)
-        for mode in ['train', 'eval']:
-            y1 = m1(x)
-            y2 = m2(x)
-            xmag = (x*x).mean()
-            y1mag = (y1*y1).mean()
-            cross1 = (x*y1).mean()
-            y2mag = (y2*y2).mean()
-            cross2 = (x*y2).mean()
-            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}")
-            m1.eval()
-            m2.eval()
-
 
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
-    _test_structured_linear()
-    _test_structured_conv1d()
-    _test_gauss_proj_drop()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
     _test_basic_norm()
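
Usage note (reviewer sketch, not part of the patch): ScaledLinear and ScaledConv1d are now plain
factory functions rather than nn.Module subclasses, so the objects they return are ordinary torch
modules that only differ in how their parameters are initialized ("just scales down initial
parameter values", as the comment added in conformer.py puts it). A minimal illustration, assuming
the scaling.py from this patch is importable; the sizes and initial_scale values below are
arbitrary examples, not values used in the recipe:

    from torch import nn
    from scaling import ScaledLinear, ScaledConv1d

    proj = ScaledLinear(512, 256, initial_scale=0.25)                # returns a plain nn.Linear
    conv = ScaledConv1d(256, 256, kernel_size=3, initial_scale=0.5)  # returns a plain nn.Conv1d

    assert isinstance(proj, nn.Linear) and isinstance(conv, nn.Conv1d)
    # Weights start at 0.25x / 0.5x their default magnitude; biases are re-initialized
    # uniformly in [-0.1 * initial_scale, 0.1 * initial_scale].
    print(proj.weight.abs().mean().item(), conv.weight.abs().mean().item())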