From 7f0756e1567916d45625e5d24a501d2d51cf3de8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 17 Jun 2022 15:10:21 +0800 Subject: [PATCH] Implement structured version of conformer --- .../pruned_transducer_stateless7/conformer.py | 19 +- .../pruned_transducer_stateless7/scaling.py | 169 +++++++++++++++++- 2 files changed, 179 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index f6d698923..491416ec1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -29,6 +29,8 @@ from scaling import ( ScaledConv1d, ScaledConv2d, ScaledLinear, + StructuredConv1d, + StructuredLinear, ) from torch import Tensor, nn @@ -464,7 +466,7 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" - self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.in_proj = StructuredLinear((embed_dim,), (3, embed_dim), bias=True) self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0) self.proj_balancer = ActivationBalancer(channel_dim=-1, min_positive=0.0, max_positive=1.0, max_abs=10.0) @@ -542,8 +544,8 @@ class RelPositionMultiheadAttention(nn.Module): pos_emb, self.embed_dim, self.num_heads, - self.in_proj.weight, - self.in_proj.bias, + self.in_proj.get_weight(), + self.in_proj.get_bias(), self.dropout, self.out_proj.weight, self.out_proj.bias, @@ -879,9 +881,9 @@ class ConvolutionModule(nn.Module): # kernerl_size should be a odd number for 'SAME' padding assert (kernel_size - 1) % 2 == 0 - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, + self.pointwise_conv1 = StructuredConv1d( + (channels,), + (2, channels), kernel_size=1, stride=1, padding=0, @@ -1021,8 +1023,9 @@ class Conv2dSubsampling(nn.Module): ActivationBalancer(channel_dim=1), DoubleSwish(), ) - self.out = nn.Linear( - layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels + out_height = (((in_channels - 1) // 2 - 1) // 2) + self.out = StructuredLinear( + (out_height, layer3_channels), (out_channels,) ) # set learn_eps=False because out_norm is preceded by `out`, and `out` # itself has learned scale, so the extra degree of freedom is not diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 71b4db7b3..ed22a6315 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -17,12 +17,14 @@ import collections from itertools import repeat -from typing import Optional, Tuple +from typing import Optional, Tuple, Union +from functools import reduce import logging import random import torch import torch.nn as nn +import torch.nn.functional as F from torch import Tensor from torch.nn import Embedding as ScaledEmbedding @@ -155,6 +157,153 @@ class BasicNorm(torch.nn.Module): return x * scales +class StructuredLinear(torch.nn.Module): + """ + This module mostly behaves like nn.Linear, but the in_features and out_features + (the number of input and output channels) are specified as tuples; the + actual numbers of channels are products over these tuples. + E.g. (2, 256) means 512, with the slowest-varying/largest-stride dims first + in terms of the layout. + For purposes of the forward() function it will behave the same as if the dim + was 512, but the parameter tensors have this structure, which makes + a difference if you are using the NeutralGradient optimizer and perhaps + certain other optimizers. + + Args: + in_features: The number of input channels, specified as + a tuple of ints (the number of input channels will be their + product). The only difference this makes is that the + nn.Parameter tensor will be shaped differently, which may + affect some optimizers. + out_features: The number of output channels, specified as + a tuple of ints. + initial_scale: The default initial parameter scale will be + multiplied by this. + bias: If true, include the bias term. + """ + def __init__(self, + in_features: Tuple[int], + out_features: Tuple[int], + bias: bool = True, + initial_scale: float = 1.0) -> None: + super(StructuredLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + in_size = reduce((lambda i,j: i*j), in_features) + out_size = reduce((lambda i,j: i*j), out_features) + self.weight_shape = (out_size, in_size) + self.weight = nn.Parameter(torch.Tensor(*out_features, *in_features)) + + if bias: + self.bias = nn.Parameter(torch.Tensor(*out_features)) + else: + self.register_parameter('bias', None) + self.reset_parameters(initial_scale) + + + def reset_parameters(self, initial_scale: float = 1.0) -> None: + nn.init.kaiming_uniform_(self.weight.reshape(*self.weight_shape), a=(5 ** 0.5)) + with torch.no_grad(): + self.weight *= initial_scale + nn.init.uniform_(self.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) + + def get_weight(self) -> Tensor: + return self.weight.reshape(*self.weight_shape) + + def get_bias(self) -> Optional[Tensor]: + return (None if self.bias is None else + self.bias.reshape(self.weight_shape[0])) + + def forward(self, input: Tensor) -> Tensor: + return F.linear(input, self.get_weight(), self.get_bias()) + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, self.out_features, self.bias is not None + ) + + +class StructuredConv1d(nn.Conv1d): + """ + This module mostly behaves like nn.Conv1d, but the + in_channels and out_channels are specified as tuples. For example, + 512 channels might be specified as + (2, 256), with slowest-varying/largest-stride dims first in terms of the layout. + For purposes of the forward() function it will behave the same as if the dim + was 512, but the parameter tensors have this structure, which makes + a difference if you are using the NeutralGradient optimizer. + + + Args: + in_channels: The number of input channels, specified as + a tuple of ints (the number of input channels will be their + product). The only difference this makes is that the + nn.Parameter tensor will be shaped differently, which may + affect some optimizers. + out_channels: The number of output channels, specified as + a tuple of ints. + initial_scale: The default initial parameter scale will be + multiplied by this. + bias: If true, include the bias term. + """ + def __init__( + self, + in_channels: Tuple[int], + out_channels: Tuple[int], + *args, + initial_scale: float = 1.0, + **kwargs + ): + super(StructuredConv1d, self).__init__( + reduce((lambda i,j: i*j), in_channels), + reduce((lambda i,j: i*j), out_channels), + *args, **kwargs) + + assert self.groups == 1, "Groups not supported as yet" + + self.in_channels = in_channels + self.out_channels = out_channels + + if self.transposed: + in_channels, out_channels = out_channels, in_channels + + self.weight_shape = self.weight.shape + self.weight = nn.Parameter(self.weight.detach().reshape( + *out_channels, *in_channels, *self.weight.shape[2:])) + + self.bias_shape = self.bias.shape + if self.bias is not None: + self.bias = nn.Parameter(self.bias.detach().reshape( + *out_channels)) + + # These changes in the initialization are the same as for class ScaledConv1d. + with torch.no_grad(): + self.weight[:] *= initial_scale + if self.bias is not None: + torch.nn.init.uniform_(self.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) + + def get_weight(self) -> Tensor: + return self.weight.reshape(*self.weight_shape) + def get_bias(self) -> Optional[Tensor]: + return (None if self.bias is None else + self.bias.reshape(*self.bias_shape)) + + def forward(self, input: Tensor) -> Tensor: + if self.padding_mode != 'zeros': + return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + self.get_weight(), self.get_bias(), self.stride, + _single(0), self.dilation, self.groups) + return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, + self.padding, self.dilation, self.groups) + + + + + class ScaledLinear(nn.Linear): """ A modified version of nn.Linear that gives an easy way to set the @@ -484,6 +633,22 @@ def _test_double_swish_deriv(): m = DoubleSwish() torch.autograd.gradcheck(m, x) +def _test_structured_linear(): + m = StructuredLinear((2, 100), (3, 100), bias=True) + assert m.weight.shape == (3, 100, 2, 100) + assert m.bias.shape == (3, 100) + x = torch.randn(50, 200) + y = m(x) + assert y.shape == (50, 300) + +def _test_structured_conv1d(): + m = StructuredConv1d((2, 100), (3, 100), kernel_size=3, padding=1, bias=True) + assert m.weight.shape == (3, 100, 2, 100, 3) + assert m.bias.shape == (3, 100) + T = 39 + x = torch.randn(50, 200, T) + y = m(x) + assert y.shape == (50, 300, T) def _test_gauss_proj_drop(): D = 384 @@ -511,6 +676,8 @@ if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) torch.set_num_interop_threads(1) + _test_structured_linear() + _test_structured_conv1d() _test_gauss_proj_drop() _test_activation_balancer_sign() _test_activation_balancer_magnitude()