Implement structured version of conformer

Daniel Povey 2022-06-17 15:10:21 +08:00
parent 2fe4af8c99
commit 7f0756e156
2 changed files with 179 additions and 9 deletions


@@ -29,6 +29,8 @@ from scaling import (
ScaledConv1d,
ScaledConv2d,
ScaledLinear,
StructuredConv1d,
StructuredLinear,
)
from torch import Tensor, nn
@@ -464,7 +466,7 @@ class RelPositionMultiheadAttention(nn.Module):
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.in_proj = StructuredLinear((embed_dim,), (3, embed_dim), bias=True)
self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
self.proj_balancer = ActivationBalancer(channel_dim=-1, min_positive=0.0,
max_positive=1.0, max_abs=10.0)
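As a quick illustration of what this change preserves (a minimal sketch; embed_dim = 512 is an assumed size, not taken from the recipe): the structured in_proj stores its weight as (3, embed_dim, embed_dim), but get_weight() flattens it back to the (3 * embed_dim, embed_dim) matrix the old nn.Linear held, so the attention math is unchanged.

import torch
from scaling import StructuredLinear

embed_dim = 512  # assumed size, for illustration only
in_proj = StructuredLinear((embed_dim,), (3, embed_dim), bias=True)
assert in_proj.weight.shape == (3, embed_dim, embed_dim)         # structured parameter
assert in_proj.get_weight().shape == (3 * embed_dim, embed_dim)  # layout nn.Linear would use
x = torch.randn(10, embed_dim)
assert in_proj(x).shape == (10, 3 * embed_dim)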
@@ -542,8 +544,8 @@ class RelPositionMultiheadAttention(nn.Module):
pos_emb,
self.embed_dim,
self.num_heads,
self.in_proj.weight,
self.in_proj.bias,
self.in_proj.get_weight(),
self.in_proj.get_bias(),
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
@@ -879,9 +881,9 @@ class ConvolutionModule(nn.Module):
# kernel_size should be an odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
self.pointwise_conv1 = StructuredConv1d(
(channels,),
(2, channels),
kernel_size=1,
stride=1,
padding=0,
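Similarly for the pointwise convolution, a small sketch of the resulting parameter shapes (channels = 256 is an assumed value, used only to make the shapes concrete):

import torch
from scaling import StructuredConv1d

channels = 256  # assumed value for illustration
pconv = StructuredConv1d((channels,), (2, channels), kernel_size=1, stride=1, padding=0)
assert pconv.weight.shape == (2, channels, channels, 1)         # structured parameter
assert pconv.get_weight().shape == (2 * channels, channels, 1)  # what F.conv1d actually receives
x = torch.randn(8, channels, 100)  # (batch, channels, time)
assert pconv(x).shape == (8, 2 * channels, 100)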
@@ -1021,8 +1023,9 @@ class Conv2dSubsampling(nn.Module):
ActivationBalancer(channel_dim=1),
DoubleSwish(),
)
self.out = nn.Linear(
layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
out_height = (((in_channels - 1) // 2 - 1) // 2)
self.out = StructuredLinear(
(out_height, layer3_channels), (out_channels,)
)
# set learn_eps=False because out_norm is preceded by `out`, and `out`
# itself has learned scale, so the extra degree of freedom is not
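The new out_height variable is the frequency dimension left after the two stride-2 convolutions. A short sketch with assumed sizes (in_channels = 80, layer3_channels = 128, out_channels = 512; illustrative values, not taken from the recipe):

from scaling import StructuredLinear

in_channels, layer3_channels, out_channels = 80, 128, 512  # assumed sizes
out_height = ((in_channels - 1) // 2 - 1) // 2             # 39 -> 19
out = StructuredLinear((out_height, layer3_channels), (out_channels,))
assert out.weight.shape == (out_channels, out_height, layer3_channels)
assert out.get_weight().shape == (out_channels, out_height * layer3_channels)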


@@ -17,12 +17,14 @@
import collections
from itertools import repeat
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
from functools import reduce
import logging
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Embedding as ScaledEmbedding
@@ -155,6 +157,153 @@ class BasicNorm(torch.nn.Module):
return x * scales
class StructuredLinear(torch.nn.Module):
"""
This module mostly behaves like nn.Linear, but the in_features and out_features
(the number of input and output channels) are specified as tuples; the
actual numbers of channels are products over these tuples.
E.g. (2, 256) means 512, with the slowest-varying/largest-stride dims first
in terms of the layout.
For purposes of the forward() function it will behave the same as if the dim
was 512, but the parameter tensors have this structure, which makes
a difference if you are using the NeutralGradient optimizer and perhaps
certain other optimizers.
Args:
in_features: The number of input channels, specified as
a tuple of ints (the number of input channels will be their
product). The only difference this makes is that the
nn.Parameter tensor will be shaped differently, which may
affect some optimizers.
out_features: The number of output channels, specified as
a tuple of ints.
initial_scale: The default initial parameter scale will be
multiplied by this.
bias: If true, include the bias term.
"""
def __init__(self,
in_features: Tuple[int],
out_features: Tuple[int],
bias: bool = True,
initial_scale: float = 1.0) -> None:
super(StructuredLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
in_size = reduce((lambda i,j: i*j), in_features)
out_size = reduce((lambda i,j: i*j), out_features)
self.weight_shape = (out_size, in_size)
self.weight = nn.Parameter(torch.Tensor(*out_features, *in_features))
if bias:
self.bias = nn.Parameter(torch.Tensor(*out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters(initial_scale)
def reset_parameters(self, initial_scale: float = 1.0) -> None:
nn.init.kaiming_uniform_(self.weight.reshape(*self.weight_shape), a=(5 ** 0.5))
with torch.no_grad():
self.weight *= initial_scale
if self.bias is not None:
    nn.init.uniform_(self.bias,
                     -0.1 * initial_scale,
                     0.1 * initial_scale)
def get_weight(self) -> Tensor:
return self.weight.reshape(*self.weight_shape)
def get_bias(self) -> Optional[Tensor]:
return (None if self.bias is None else
self.bias.reshape(self.weight_shape[0]))
def forward(self, input: Tensor) -> Tensor:
return F.linear(input, self.get_weight(), self.get_bias())
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None
)
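Since forward() only changes how the parameter is stored, a StructuredLinear should agree numerically with an nn.Linear that is given the flattened weight and bias. A minimal check (shapes chosen arbitrarily):

import torch
import torch.nn as nn

m = StructuredLinear((2, 100), (3, 100), bias=True)
ref = nn.Linear(200, 300, bias=True)
with torch.no_grad():
    ref.weight.copy_(m.get_weight())  # (300, 200)
    ref.bias.copy_(m.get_bias())      # (300,)
x = torch.randn(4, 200)
assert torch.allclose(m(x), ref(x))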
class StructuredConv1d(nn.Conv1d):
"""
This module mostly behaves like nn.Conv1d, but the
in_channels and out_channels are specified as tuples. For example,
512 channels might be specified as
(2, 256), with slowest-varying/largest-stride dims first in terms of the layout.
For purposes of the forward() function it will behave the same as if the dim
was 512, but the parameter tensors have this structure, which makes
a difference if you are using the NeutralGradient optimizer.
Args:
in_channels: The number of input channels, specified as
a tuple of ints (the number of input channels will be their
product). The only difference this makes is that the
nn.Parameter tensor will be shaped differently, which may
affect some optimizers.
out_channels: The number of output channels, specified as
a tuple of ints.
initial_scale: The default initial parameter scale will be
multiplied by this.
bias: If true, include the bias term.
"""
def __init__(
self,
in_channels: Tuple[int],
out_channels: Tuple[int],
*args,
initial_scale: float = 1.0,
**kwargs
):
super(StructuredConv1d, self).__init__(
reduce((lambda i,j: i*j), in_channels),
reduce((lambda i,j: i*j), out_channels),
*args, **kwargs)
assert self.groups == 1, "Groups not supported as yet"
self.in_channels = in_channels
self.out_channels = out_channels
if self.transposed:
in_channels, out_channels = out_channels, in_channels
self.weight_shape = self.weight.shape
self.weight = nn.Parameter(self.weight.detach().reshape(
*out_channels, *in_channels, *self.weight.shape[2:]))
if self.bias is not None:
    self.bias_shape = self.bias.shape
    self.bias = nn.Parameter(self.bias.detach().reshape(
        *out_channels))
# These changes in the initialization are the same as for class ScaledConv1d.
with torch.no_grad():
self.weight[:] *= initial_scale
if self.bias is not None:
torch.nn.init.uniform_(self.bias,
-0.1 * initial_scale,
0.1 * initial_scale)
def get_weight(self) -> Tensor:
return self.weight.reshape(*self.weight_shape)
def get_bias(self) -> Optional[Tensor]:
return (None if self.bias is None else
self.bias.reshape(*self.bias_shape))
def forward(self, input: Tensor) -> Tensor:
if self.padding_mode != 'zeros':
return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
self.get_weight(), self.get_bias(), self.stride,
_single(0), self.dilation, self.groups)
return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
self.padding, self.dilation, self.groups)
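The same kind of check applies here: once the weight and bias are flattened back, StructuredConv1d should match a plain nn.Conv1d built with the product dimensions (the shapes below are arbitrary illustrative choices):

import torch
import torch.nn as nn

m = StructuredConv1d((2, 100), (3, 100), kernel_size=3, padding=1, bias=True)
ref = nn.Conv1d(200, 300, kernel_size=3, padding=1, bias=True)
with torch.no_grad():
    ref.weight.copy_(m.get_weight())  # (300, 200, 3)
    ref.bias.copy_(m.get_bias())      # (300,)
x = torch.randn(2, 200, 17)
assert torch.allclose(m(x), ref(x), atol=1e-6)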
class ScaledLinear(nn.Linear):
"""
A modified version of nn.Linear that gives an easy way to set the
@@ -484,6 +633,22 @@ def _test_double_swish_deriv():
m = DoubleSwish()
torch.autograd.gradcheck(m, x)
def _test_structured_linear():
m = StructuredLinear((2, 100), (3, 100), bias=True)
assert m.weight.shape == (3, 100, 2, 100)
assert m.bias.shape == (3, 100)
x = torch.randn(50, 200)
y = m(x)
assert y.shape == (50, 300)
def _test_structured_conv1d():
m = StructuredConv1d((2, 100), (3, 100), kernel_size=3, padding=1, bias=True)
assert m.weight.shape == (3, 100, 2, 100, 3)
assert m.bias.shape == (3, 100)
T = 39
x = torch.randn(50, 200, T)
y = m(x)
assert y.shape == (50, 300, T)
def _test_gauss_proj_drop():
D = 384
@ -511,6 +676,8 @@ if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
_test_structured_linear()
_test_structured_conv1d()
_test_gauss_proj_drop()
_test_activation_balancer_sign()
_test_activation_balancer_magnitude()