Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00
Implement structured version of conformer
This commit is contained in:
parent 2fe4af8c99, commit 7f0756e156
conformer.py

@@ -29,6 +29,8 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
+    StructuredConv1d,
+    StructuredLinear,
 )
 from torch import Tensor, nn
 
@@ -464,7 +466,7 @@ class RelPositionMultiheadAttention(nn.Module):
             self.head_dim * num_heads == self.embed_dim
         ), "embed_dim must be divisible by num_heads"
 
-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
+        self.in_proj = StructuredLinear((embed_dim,), (3, embed_dim), bias=True)
         self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
         self.proj_balancer = ActivationBalancer(channel_dim=-1, min_positive=0.0,
                                                 max_positive=1.0, max_abs=10.0)
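Here (3, embed_dim) factors the 3 * embed_dim outputs of the old nn.Linear (the stacked query/key/value projections) into an explicit tuple, while (embed_dim,) leaves the input dimension unfactored. The forward computation is unchanged; only the stored parameter shape differs (see the StructuredLinear class added below in scaling.py).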
@@ -542,8 +544,8 @@ class RelPositionMultiheadAttention(nn.Module):
             pos_emb,
             self.embed_dim,
             self.num_heads,
-            self.in_proj.weight,
-            self.in_proj.bias,
+            self.in_proj.get_weight(),
+            self.in_proj.get_bias(),
             self.dropout,
             self.out_proj.weight,
             self.out_proj.bias,
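The attention computation downstream expects a flat (3 * embed_dim, embed_dim) weight and a flat (3 * embed_dim,) bias, so the call site now goes through get_weight()/get_bias(), which return flattened views of the structured parameters.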
@@ -879,9 +881,9 @@ class ConvolutionModule(nn.Module):
         # kernel_size should be an odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0
 
-        self.pointwise_conv1 = nn.Conv1d(
-            channels,
-            2 * channels,
+        self.pointwise_conv1 = StructuredConv1d(
+            (channels,),
+            (2, channels),
             kernel_size=1,
             stride=1,
             padding=0,
@@ -1021,8 +1023,9 @@ class Conv2dSubsampling(nn.Module):
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
         )
-        self.out = nn.Linear(
-            layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
+        out_height = (((in_channels - 1) // 2 - 1) // 2)
+        self.out = StructuredLinear(
+            (out_height, layer3_channels), (out_channels,)
         )
         # set learn_eps=False because out_norm is preceded by `out`, and `out`
         # itself has learned scale, so the extra degree of freedom is not
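The new out_height variable makes the subsampling arithmetic readable: two stride-2 convolutions shrink the frequency axis twice. A worked example with assumed values (80-dim filterbank input, layer3_channels = 128; both numbers are illustrative, not taken from this diff):

in_channels = 80        # assumed: 80-dim filterbank features
layer3_channels = 128   # assumed value, for illustration only
out_height = (((in_channels - 1) // 2 - 1) // 2)   # 80 -> 39 -> 19
assert out_height == 19
# Old: nn.Linear(layer3_channels * out_height, out_channels), i.e. 2432 inputs.
# New: StructuredLinear((19, 128), (out_channels,)): the same 2432 input
# features, with the (height, channels) factorization kept in the parameter shape.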

scaling.py

@@ -17,12 +17,14 @@
 
 import collections
 from itertools import repeat
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
+from functools import reduce
 import logging
 
 import random
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Embedding as ScaledEmbedding
 
@@ -155,6 +157,153 @@ class BasicNorm(torch.nn.Module):
         return x * scales
 
 
+class StructuredLinear(torch.nn.Module):
+    """
+    This module mostly behaves like nn.Linear, but the in_features and
+    out_features (the numbers of input and output channels) are specified as
+    tuples; the actual numbers of channels are the products over these tuples.
+    E.g. (2, 256) means 512, with the slowest-varying/largest-stride dims
+    first in terms of the layout.
+    For purposes of the forward() function it behaves the same as if the dim
+    were 512, but the parameter tensors have this structure, which makes a
+    difference if you are using the NeutralGradient optimizer and perhaps
+    certain other optimizers.
+
+    Args:
+        in_features: The number of input channels, specified as
+           a tuple of ints (the number of input channels will be their
+           product).  The only difference this makes is that the
+           nn.Parameter tensor will be shaped differently, which may
+           affect some optimizers.
+        out_features: The number of output channels, specified as
+           a tuple of ints.
+        initial_scale: The default initial parameter scale will be
+           multiplied by this.
+        bias: If true, include the bias term.
+    """
+    def __init__(self,
+                 in_features: Tuple[int],
+                 out_features: Tuple[int],
+                 bias: bool = True,
+                 initial_scale: float = 1.0) -> None:
+        super(StructuredLinear, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        in_size = reduce((lambda i, j: i * j), in_features)
+        out_size = reduce((lambda i, j: i * j), out_features)
+        self.weight_shape = (out_size, in_size)
+        self.weight = nn.Parameter(torch.Tensor(*out_features, *in_features))
+
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(*out_features))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_parameters(initial_scale)
+
+    def reset_parameters(self, initial_scale: float = 1.0) -> None:
+        nn.init.kaiming_uniform_(self.weight.reshape(*self.weight_shape),
+                                 a=(5 ** 0.5))
+        with torch.no_grad():
+            self.weight *= initial_scale
+        if self.bias is not None:  # no-op when constructed with bias=False
+            nn.init.uniform_(self.bias,
+                             -0.1 * initial_scale,
+                             0.1 * initial_scale)
+
+    def get_weight(self) -> Tensor:
+        # Flat (out_size, in_size) view, as nn.Linear would store it.
+        return self.weight.reshape(*self.weight_shape)
+
+    def get_bias(self) -> Optional[Tensor]:
+        return (None if self.bias is None else
+                self.bias.reshape(self.weight_shape[0]))
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.linear(input, self.get_weight(), self.get_bias())
+
+    def extra_repr(self) -> str:
+        return 'in_features={}, out_features={}, bias={}'.format(
+            self.in_features, self.out_features, self.bias is not None
+        )
+
+
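A short sanity check, not from the patch itself, that StructuredLinear computes exactly what nn.Linear would with the flattened weight; it assumes scaling.py is importable:

import torch
import torch.nn as nn
from scaling import StructuredLinear

m = StructuredLinear((2, 3), (4,), bias=True)
assert m.weight.shape == (4, 2, 3)     # structured parameter layout
ref = nn.Linear(6, 4)
with torch.no_grad():
    ref.weight.copy_(m.get_weight())   # flat (4, 6) view
    ref.bias.copy_(m.get_bias())
x = torch.randn(5, 6)
assert torch.allclose(m(x), ref(x))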
+class StructuredConv1d(nn.Conv1d):
+    """
+    This module mostly behaves like nn.Conv1d, but the in_channels and
+    out_channels are specified as tuples.  For example, 512 channels might be
+    specified as (2, 256), with the slowest-varying/largest-stride dims first
+    in terms of the layout.
+    For purposes of the forward() function it behaves the same as if the dim
+    were 512, but the parameter tensors have this structure, which makes a
+    difference if you are using the NeutralGradient optimizer.
+
+    Args:
+        in_channels: The number of input channels, specified as
+           a tuple of ints (the number of input channels will be their
+           product).  The only difference this makes is that the
+           nn.Parameter tensor will be shaped differently, which may
+           affect some optimizers.
+        out_channels: The number of output channels, specified as
+           a tuple of ints.
+        initial_scale: The default initial parameter scale will be
+           multiplied by this.
+        bias: If true, include the bias term.
+    """
+    def __init__(
+        self,
+        in_channels: Tuple[int],
+        out_channels: Tuple[int],
+        *args,
+        initial_scale: float = 1.0,
+        **kwargs
+    ):
+        super(StructuredConv1d, self).__init__(
+            reduce((lambda i, j: i * j), in_channels),
+            reduce((lambda i, j: i * j), out_channels),
+            *args, **kwargs)
+
+        assert self.groups == 1, "Groups not supported as yet"
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        if self.transposed:
+            in_channels, out_channels = out_channels, in_channels
+
+        self.weight_shape = self.weight.shape
+        self.weight = nn.Parameter(self.weight.detach().reshape(
+            *out_channels, *in_channels, *self.weight.shape[2:]))
+
+        if self.bias is not None:
+            self.bias_shape = self.bias.shape
+            self.bias = nn.Parameter(self.bias.detach().reshape(
+                *out_channels))
+
+        # These changes in the initialization are the same as for class ScaledConv1d.
+        with torch.no_grad():
+            self.weight[:] *= initial_scale
+            if self.bias is not None:
+                torch.nn.init.uniform_(self.bias,
+                                       -0.1 * initial_scale,
+                                       0.1 * initial_scale)
+
+    def get_weight(self) -> Tensor:
+        # Flat view with nn.Conv1d's native (out, in, kernel_size) shape.
+        return self.weight.reshape(*self.weight_shape)
+
+    def get_bias(self) -> Optional[Tensor]:
+        return (None if self.bias is None else
+                self.bias.reshape(*self.bias_shape))
+
+    def forward(self, input: Tensor) -> Tensor:
+        if self.padding_mode != 'zeros':
+            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice,
+                                  mode=self.padding_mode),
+                            self.get_weight(), self.get_bias(), self.stride,
+                            _single(0), self.dilation, self.groups)
+        return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
+                        self.padding, self.dilation, self.groups)
+
+
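A corresponding sketch for StructuredConv1d, again an illustration rather than part of the patch, assuming scaling.py is importable:

import torch
from scaling import StructuredConv1d

m = StructuredConv1d((2, 256), (512,), kernel_size=3, padding=1)
assert m.weight.shape == (512, 2, 256, 3)  # (*out_factors, *in_factors, kernel)
x = torch.randn(8, 512, 100)               # (batch, channels, time)
y = m(x)                                   # acts like nn.Conv1d(512, 512, 3, padding=1)
assert y.shape == (8, 512, 100)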
 class ScaledLinear(nn.Linear):
     """
     A modified version of nn.Linear that gives an easy way to set the
@@ -484,6 +633,22 @@ def _test_double_swish_deriv():
     m = DoubleSwish()
     torch.autograd.gradcheck(m, x)
 
+def _test_structured_linear():
+    m = StructuredLinear((2, 100), (3, 100), bias=True)
+    assert m.weight.shape == (3, 100, 2, 100)
+    assert m.bias.shape == (3, 100)
+    x = torch.randn(50, 200)
+    y = m(x)
+    assert y.shape == (50, 300)
+
+def _test_structured_conv1d():
+    m = StructuredConv1d((2, 100), (3, 100), kernel_size=3, padding=1, bias=True)
+    assert m.weight.shape == (3, 100, 2, 100, 3)
+    assert m.bias.shape == (3, 100)
+    T = 39
+    x = torch.randn(50, 200, T)
+    y = m(x)
+    assert y.shape == (50, 300, T)
+
 def _test_gauss_proj_drop():
     D = 384
@@ -511,6 +676,8 @@ if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_structured_linear()
+    _test_structured_conv1d()
     _test_gauss_proj_drop()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
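Since scaling.py runs these checks from its __main__ block, executing the file directly (python scaling.py) exercises the new modules alongside the existing tests.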