Another rework, use scales on linear/conv

Daniel Povey 2022-03-12 15:38:13 +08:00
parent 0abba9e7a2
commit ca8cf2a73b
2 changed files with 140 additions and 89 deletions

View File

@@ -44,20 +44,20 @@ class Conv2dSubsampling(nn.Module):
assert idim >= 7
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(
ScaledConv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
DerivBalancer(channel_dim=1, threshold=0.05,
max_factor=0.01),
ExpScaleRelu(odim, 1, 1, speed=20.0),
nn.Conv2d(
ScaledConv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
DerivBalancer(channel_dim=1, threshold=0.05,
max_factor=0.01),
ExpScaleRelu(odim, 1, 1, speed=20.0),
)
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
self.out_norm = BasicNorm(odim)
self._reset_parameters()
@@ -221,21 +221,18 @@ class ExpScale(torch.nn.Module):
def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor:
def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
# double-swish, implemented/approximated as offset-swish
if in_scale != 1.0:
x = x * in_scale
x = (x * torch.sigmoid(x - 1.0))
x = x * (scale * speed).exp()
return x
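For reference, the computation above is an offset-swish, x * sigmoid(x - 1.0), followed by a learned exponential scale exp(scale * speed). A minimal standalone sketch of the same math, assuming only torch (the helper name exp_scale_swish_ref is illustrative, not from this commit):

import torch

def exp_scale_swish_ref(x: torch.Tensor, scale: torch.Tensor, speed: float) -> torch.Tensor:
    # offset-swish ("double-swish" approximation), then a learned log-scale
    return x * torch.sigmoid(x - 1.0) * (scale * speed).exp()

x = torch.randn(4, 8)
scale = torch.zeros(())  # log-scale parameter: exp(0 * speed) == 1 at init
y = exp_scale_swish_ref(x, scale, speed=20.0)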
class SwishExpScaleFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor:
def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
ctx.save_for_backward(x.detach(), scale.detach())
ctx.speed = speed
ctx.in_scale = in_scale
return _exp_scale_swish(x, scale, speed, in_scale)
return _exp_scale_swish(x, scale, speed)
@staticmethod
def backward(ctx, y_grad: Tensor) -> Tensor:
@@ -243,25 +240,24 @@ class SwishExpScaleFunction(torch.autograd.Function):
x.requires_grad = True
scale.requires_grad = True
with torch.enable_grad():
y = _exp_scale_swish(x, scale, ctx.speed, ctx.in_scale)
y = _exp_scale_swish(x, scale, ctx.speed)
y.backward(gradient=y_grad)
return x.grad, scale.grad, None, None
return x.grad, scale.grad, None
class SwishExpScale(torch.nn.Module):
# combines ExpScale and a Swish (actually the ExpScale is after the Swish).
# caution: speed must be passed by keyword, e.g. SwishExpScale(50, speed=4.0)
#
def __init__(self, *shape, speed: float = 1.0, in_scale: float = 1.0):
def __init__(self, *shape, speed: float = 1.0):
super(SwishExpScale, self).__init__()
self.in_scale = in_scale
initial_log_scale = torch.tensor(1.0 / in_scale).log() / speed
initial_log_scale = (torch.ones(*shape) * initial_log_scale).detach()
initial_log_scale = torch.zeros(()).detach()
self.scale = nn.Parameter(initial_log_scale)
self.speed = speed
def forward(self, x: Tensor) -> Tensor:
return SwishExpScaleFunction.apply(x, self.scale, self.speed, self.in_scale)
return SwishExpScaleFunction.apply(x, self.scale, self.speed)
# x = (x * torch.sigmoid(x))
# x = (x * torch.sigmoid(x))
# x = x * (self.scale * self.speed).exp()
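The custom autograd Function above saves only detached inputs and recomputes the activation under torch.enable_grad() inside backward(), keeping memory low at the cost of a second forward pass. One way to sanity-check such a hand-written backward is gradcheck; a sketch, assuming the first changed file is importable as subsampling (matching the import in the conformer file below):

import torch
from subsampling import SwishExpScale  # module name taken from the import below

m = SwishExpScale(256, speed=20.0).double()
x = torch.randn(3, 256, dtype=torch.double, requires_grad=True)
# The recomputed backward should agree with numerical gradients:
torch.autograd.gradcheck(lambda t: m(t), (x,))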
@@ -383,12 +379,11 @@ class BasicNorm(torch.nn.Module):
interpreted as an offset from the input's ndim if negative.
This is NOT the num_channels; it should typically be one of
{-2, -1, 0, 1, 2, 3}.
initial_eps_scale: a constant that determines the initial
"epsilon" that we add as ballast in:
scale = output_scale * ((input_vec**2).sum() + epsilon)**-0.5
Note: our epsilon is actually large, not small, but we keep the name
to indicate the connection with normal LayerNorm. We set
epsilon initially to num_channels * initial_eps_scale.
initial_eps: the initial "epsilon" that we add as ballast in:
scale = ((input_vec**2).mean() + epsilon)**-0.5
Note: our epsilon is actually large, but we keep the name
to indicate the connection with normal LayerNorm.
speed: a scaling factor that can be interpreted as scaling the learning
rate for this module. CAUTION: the default value of 10.0 is intended to be
used with Adam or amsgrad-type optimizers, e.g. Adam or Noam.
@@ -398,42 +393,101 @@ class BasicNorm(torch.nn.Module):
def __init__(self,
num_channels: int,
channel_dim: int = -1, # CAUTION: see documentation.
initial_eps_scale: float = 0.25,
speed: float = 10.0):
eps: float = 0.25):
super(BasicNorm, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
self.speed = speed
eps = num_channels * initial_eps_scale
# log_eps = log(eps) / speed
log_eps = torch.tensor(eps).log() / speed
self.log_eps = nn.Parameter(log_eps.detach())
# initial output-scale, to get LayerNorm-like behavior, is
# sqrt(num_channels).
initial_scale = torch.tensor(num_channels ** 0.5).log() / speed
self.log_scale = nn.Parameter(initial_scale.detach())
def _inner(self, x: Tensor) -> Tensor:
# inner product on last dim of x, keeping the dimension,
# i.e. torch.sum(x**2, dim=-1, keepdim=True), but more
# efficient.
if hasattr(torch, 'inner'):
return torch.inner(x).unsqueeze(-1)
else:
# TODO: we can do this with matrix multiplication, maybe.
return torch.sum(x**2, dim=-1, keepdim=True)
self.eps = eps
def forward(self, x: Tensor) -> Tensor:
assert x.shape[self.channel_dim] == self.num_channels
x = x.transpose(-1, self.channel_dim)
eps = (self.log_eps * self.speed).exp()
out_scale = (self.log_scale * self.speed).exp()
scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps) ** -0.5
return x * scales
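For reference, the new BasicNorm is an RMS-style normalization: no mean subtraction, no learned affine transform, and a deliberately large fixed epsilon acting as ballast. A small sketch of the same computation for channel_dim == -1 (the function name is illustrative):

import torch

def basic_norm_ref(x: torch.Tensor, eps: float = 0.25) -> torch.Tensor:
    # scale each frame by the inverse RMS over channels, with a large epsilon
    scales = (x.pow(2).mean(dim=-1, keepdim=True) + eps) ** -0.5
    return x * scales

y = basic_norm_ref(torch.randn(5, 256))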
class ScaledLinear(nn.Linear):
def __init__(self, *args, scale_speed=5.0, **kwargs):
super(ScaledLinear, self).__init__(*args, **kwargs)
self.weight_scale = nn.Parameter(torch.zeros(()))
self.scale_speed = scale_speed
if self.bias is not None:
self.bias_scale = nn.Parameter(torch.zeros(()))
else:
self.register_parameter('bias_scale', None)
def get_weight(self):
return self.weight * (self.weight_scale * self.scale_speed).exp()
def get_bias(self):
return (None if self.bias is None else
self.bias * (self.bias_scale * self.scale_speed).exp())
def forward(self, input: Tensor) -> Tensor:
return torch.nn.functional.linear(input, self.get_weight(),
self.get_bias())
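Because weight_scale and bias_scale are initialized to zero, exp(0 * scale_speed) == 1, so a freshly constructed ScaledLinear behaves exactly like the underlying nn.Linear; the scales only take effect once the optimizer moves them. A quick check, again assuming the file is importable as subsampling:

import torch
from subsampling import ScaledLinear  # module name taken from the import below

lin = ScaledLinear(256, 512)
x = torch.randn(8, 256)
# At init the learned log-scales are zero, so this matches a plain linear
# layer that shares the same weight and bias:
y_ref = torch.nn.functional.linear(x, lin.weight, lin.bias)
assert torch.allclose(lin(x), y_ref)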
class ScaledConv1d(nn.Conv1d):
def __init__(self, *args, scale_speed = 5.0, **kwargs):
super(ScaledConv1d, self).__init__(*args, **kwargs)
self.scale_speed = scale_speed
self.weight_scale = nn.Parameter(torch.zeros(()))
if self.bias is not None:
self.bias_scale = nn.Parameter(torch.zeros(()))
else:
self.register_parameter('bias_scale', None)
def get_weight(self):
return self.weight * (self.weight_scale * self.scale_speed).exp()
def get_bias(self):
return (None if self.bias is None else
self.bias * (self.bias_scale * self.scale_speed).exp())
def forward(self, input: Tensor) -> Tensor:
F = torch.nn.functional
if self.padding_mode != 'zeros':
return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
self.get_weight(), self.get_bias(), self.stride,
_single(0), self.dilation, self.groups)
return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
self.padding, self.dilation, self.groups)
class ScaledConv2d(nn.Conv2d):
def __init__(self, *args, scale_speed=5.0, **kwargs):
super(ScaledConv2d, self).__init__(*args, **kwargs)
self.scale_speed = scale_speed
self.weight_scale = nn.Parameter(torch.zeros(()))
if self.bias is not None:
self.bias_scale = nn.Parameter(torch.zeros(()))
else:
self.register_parameter('bias_scale', None)
def get_weight(self):
return self.weight * (self.weight_scale * self.scale_speed).exp()
def get_bias(self):
return (None if self.bias is None else
self.bias * (self.bias_scale * self.scale_speed).exp())
def _conv_forward(self, input, weight):
F = torch.nn.functional
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, self.get_bias(), self.stride,
_pair(0), self.dilation, self.groups)
return F.conv2d(input, weight, self.get_bias(), self.stride,
self.padding, self.dilation, self.groups)
def forward(self, input: Tensor) -> Tensor:
return self._conv_forward(input, self.get_weight())
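Note that _single and _pair in the padded branches are the tuple helpers from torch.nn.modules.utils, assumed to be imported at the top of the file (the import section is not part of this hunk). Since get_weight()/get_bias() return the effective parameters with the log-scales folded in, one possible downstream use (not something this commit does) is converting a trained ScaledConv2d back to a plain nn.Conv2d, e.g. for export; a sketch:

import torch
from torch import nn
from subsampling import ScaledConv2d  # module name taken from the import below

sc = ScaledConv2d(in_channels=1, out_channels=32, kernel_size=3, stride=2)
plain = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=2)
with torch.no_grad():
    plain.weight.copy_(sc.get_weight())  # fold the learned scale into the weight
    plain.bias.copy_(sc.get_bias())      # ... and into the bias
x = torch.randn(2, 1, 19, 80)
assert torch.allclose(sc(x), plain(x), atol=1e-6)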
scales = out_scale * (self._inner(x) + eps) ** -0.5
x = x * scales
x = x.transpose(-1, self.channel_dim)
return x
@@ -576,6 +630,8 @@ def _test_basic_norm():
if __name__ == '__main__':
_test_deriv_balancer_sign()
_test_deriv_balancer_magnitude()

View File

@@ -19,7 +19,7 @@ import copy
import math
import warnings
from typing import Optional, Tuple, Sequence
from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm
from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
import torch
from torch import Tensor, nn
@@ -157,30 +157,25 @@ class ConformerEncoderLayer(nn.Module):
)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
ScaledLinear(d_model, dim_feedforward),
DerivBalancer(channel_dim=-1, threshold=0.05,
max_factor=0.01),
SwishExpScale(dim_feedforward, speed=20.0, in_scale=2.0),
SwishExpScale(dim_feedforward, speed=20.0),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
ScaledLinear(dim_feedforward, d_model),
)
self.feed_forward_macaron = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
ScaledLinear(d_model, dim_feedforward),
DerivBalancer(channel_dim=-1, threshold=0.05,
max_factor=0.01),
SwishExpScale(dim_feedforward, speed=20.0, in_scale=2.0),
SwishExpScale(dim_feedforward, speed=20.0),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
ScaledLinear(dim_feedforward, d_model),
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2)
self.post_scale_mha = ExpScale(1, speed=10.0, initial_scale=1.0)
self.scale_conv = ExpScale(1, speed=10.0, initial_scale=0.5)
self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5)
self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5)
self.pre_norm_final = Identity()
self.norm_final = BasicNorm(d_model)
@@ -216,13 +211,10 @@ class ConformerEncoderLayer(nn.Module):
residual = src
src = src + self.dropout(self.feed_forward_macaron(
self.scale_ff_macaron(src)))
src = src + self.dropout(self.feed_forward_macaron(src))
# multi-headed self-attention module
residual = src
src = self.scale_mha(src)
src_att = self.self_attn(
src,
src,
@@ -231,13 +223,13 @@ class ConformerEncoderLayer(nn.Module):
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = residual + self.post_scale_mha(self.dropout(src_att))
src = src + self.dropout(src_att)
# convolution module
src = src + self.dropout(self.conv_module(self.scale_conv(src)))
src = src + self.dropout(self.conv_module(src))
# feed forward module
src = src + self.dropout(self.feed_forward(self.scale_ff(src)))
src = src + self.dropout(self.feed_forward(src))
src = self.norm_final(self.pre_norm_final(src))
@@ -420,6 +412,7 @@ class RelPositionMultiheadAttention(nn.Module):
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
scale_speed: float = 5.0
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
@@ -430,18 +423,27 @@ class RelPositionMultiheadAttention(nn.Module):
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True)
# linear transformation for positional encoding.
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
# these two learnable biases are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.scale_speed = scale_speed
self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
self._reset_parameters()
def _pos_bias_u(self):
return self.pos_bias_u * (self.pos_bias_u_scale * self.scale_speed).exp()
def _pos_bias_v(self):
return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp()
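The same zero-initialized log-scale trick used by ScaledLinear is applied here directly to the raw pos_bias_u / pos_bias_v parameters; any nn.Parameter can be paired with a scalar log-scale this way. A generic sketch of the pattern (the class name is illustrative, not from this commit):

import torch
from torch import nn

class ScaledParam(nn.Module):
    # a parameter with a learned exponential scale; exp(0) == 1 at init
    def __init__(self, *shape, scale_speed: float = 5.0):
        super().__init__()
        self.param = nn.Parameter(0.02 * torch.randn(*shape))
        self.log_scale = nn.Parameter(torch.zeros(()))
        self.scale_speed = scale_speed

    def get(self) -> torch.Tensor:
        return self.param * (self.log_scale * self.scale_speed).exp()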
def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
@@ -508,11 +510,11 @@ class RelPositionMultiheadAttention(nn.Module):
pos_emb,
self.embed_dim,
self.num_heads,
self.in_proj.weight,
self.in_proj.bias,
self.in_proj.get_weight(),
self.in_proj.get_bias(),
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
self.out_proj.get_weight(),
self.out_proj.get_bias(),
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
@@ -743,11 +745,11 @@ class RelPositionMultiheadAttention(nn.Module):
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(
q_with_bias_u = (q + self._pos_bias_u()).transpose(
1, 2
) # (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(
q_with_bias_v = (q + self._pos_bias_v()).transpose(
1, 2
) # (batch, head, time1, d_k)
@@ -842,7 +844,7 @@ class ConvolutionModule(nn.Module):
# kernel_size should be an odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = nn.Conv1d(
self.pointwise_conv1 = ScaledConv1d(
channels,
2 * channels,
kernel_size=1,
@@ -850,7 +852,7 @@ class ConvolutionModule(nn.Module):
padding=0,
bias=bias,
)
self.depthwise_conv = nn.Conv1d(
self.depthwise_conv = ScaledConv1d(
channels,
channels,
kernel_size,
@@ -860,12 +862,10 @@ class ConvolutionModule(nn.Module):
bias=bias,
)
self.scale = ExpScale(1, speed=10.0, initial_scale=1.0)
# shape: (channels, 1), broadcasts with (batch, channel, time).
self.activation = SwishOffset()
self.pointwise_conv2 = nn.Conv1d(
self.pointwise_conv2 = ScaledConv1d(
channels,
channels,
kernel_size=1,
@@ -897,11 +897,6 @@ class ConvolutionModule(nn.Module):
# TODO: can have a learned scale in here, or a fixed one.
x = self.activation(x)
# x is (batch, channels, time)
x = x.permute(0, 2, 1)
x = self.scale(x)
x = x.permute(0, 2, 1)
x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1)
@@ -982,7 +977,7 @@ class RandomCombine(torch.nn.Module):
assert pure_prob >= 0 and pure_prob <= 1
assert final_weight > 0 and final_weight < 1
assert num_inputs >= 1
self.linear = nn.ModuleList([nn.Linear(num_channels, num_channels, bias=True)
self.linear = nn.ModuleList([ScaledLinear(num_channels, num_channels, bias=True)
for _ in range(num_inputs - 1)])
self.num_inputs = num_inputs