Another rework, use scales on linear/conv

commit ca8cf2a73b (parent 0abba9e7a2), k2-fsa/icefall (https://github.com/k2-fsa/icefall.git)
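
The recurring change in this commit, as I read the diff: nn.Linear / nn.Conv1d / nn.Conv2d are replaced by Scaled* variants that keep the usual weight and bias plus one learnable scalar log-scale per tensor, and compute their effective parameters on every forward pass as

    weight_eff = weight * exp(weight_scale * scale_speed)
    bias_eff   = bias   * exp(bias_scale   * scale_speed)

with the scales initialized to zero (exp(0) = 1, so each module starts out identical to its unscaled counterpart) and scale_speed defaulting to 5.0. The per-branch ExpScale modules in the Conformer layer and the learned output scale in BasicNorm are removed, presumably because the layers themselves now carry a learnable scale. The first group of hunks below is from the subsampling module, where ScaledLinear, ScaledConv1d, ScaledConv2d and the simplified BasicNorm are defined; the hunks from "@@ -19,7 +19,7 @@ import copy" onward are from the Conformer encoder file that imports them.
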
@@ -44,20 +44,20 @@ class Conv2dSubsampling(nn.Module):
         assert idim >= 7
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(
+            ScaledConv2d(
                 in_channels=1, out_channels=odim, kernel_size=3, stride=2
             ),
             DerivBalancer(channel_dim=1, threshold=0.05,
                           max_factor=0.01),
             ExpScaleRelu(odim, 1, 1, speed=20.0),
-            nn.Conv2d(
+            ScaledConv2d(
                 in_channels=odim, out_channels=odim, kernel_size=3, stride=2
             ),
             DerivBalancer(channel_dim=1, threshold=0.05,
                           max_factor=0.01),
             ExpScaleRelu(odim, 1, 1, speed=20.0),
         )
-        self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
         self.out_norm = BasicNorm(odim)
         self._reset_parameters()
 
@@ -221,21 +221,18 @@ class ExpScale(torch.nn.Module):
 
 
 
-def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor:
+def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
     # double-swish, implemented/approximated as offset-swish
-    if in_scale != 1.0:
-        x = x * in_scale
     x = (x * torch.sigmoid(x - 1.0))
     x = x * (scale * speed).exp()
     return x
 
 class SwishExpScaleFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor:
+    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
         ctx.save_for_backward(x.detach(), scale.detach())
         ctx.speed = speed
-        ctx.in_scale = in_scale
-        return _exp_scale_swish(x, scale, speed, in_scale)
+        return _exp_scale_swish(x, scale, speed)
 
     @staticmethod
     def backward(ctx, y_grad: Tensor) -> Tensor:
@@ -243,25 +240,24 @@ class SwishExpScaleFunction(torch.autograd.Function):
         x.requires_grad = True
         scale.requires_grad = True
         with torch.enable_grad():
-            y = _exp_scale_swish(x, scale, ctx.speed, ctx.in_scale)
+            y = _exp_scale_swish(x, scale, ctx.speed)
             y.backward(gradient=y_grad)
-            return x.grad, scale.grad, None, None
+            return x.grad, scale.grad, None
 
 
 class SwishExpScale(torch.nn.Module):
     # combines ExpScale and a Swish (actually the ExpScale is after the Swish).
     # caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0)
     #
-    def __init__(self, *shape, speed: float = 1.0, in_scale: float = 1.0):
+    def __init__(self, *shape, speed: float = 1.0):
         super(SwishExpScale, self).__init__()
-        self.in_scale = in_scale
-        initial_log_scale = torch.tensor(1.0 / in_scale).log() / speed
-        initial_log_scale = (torch.ones(*shape) * initial_log_scale).detach()
+        initial_log_scale = torch.zeros(()).detach()
         self.scale = nn.Parameter(initial_log_scale)
         self.speed = speed
 
     def forward(self, x: Tensor) -> Tensor:
-        return SwishExpScaleFunction.apply(x, self.scale, self.speed, self.in_scale)
+        return SwishExpScaleFunction.apply(x, self.scale, self.speed)
         # x = (x * torch.sigmoid(x))
         # x = (x * torch.sigmoid(x))
         # x = x * (self.scale * self.speed).exp()
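
For reference, the math that _exp_scale_swish now computes (the autograd.Function above saves only x and scale and re-runs this computation inside backward(), rather than keeping intermediate activations around). A minimal stand-alone sketch, not the module itself:

import torch

def exp_scale_swish_reference(x: torch.Tensor, scale: torch.Tensor, speed: float) -> torch.Tensor:
    # offset-swish followed by a learned exponential gain
    x = x * torch.sigmoid(x - 1.0)
    return x * (scale * speed).exp()

x = torch.randn(5, 8)
scale = torch.zeros(8)   # zero log-scale, so the gain is exp(0) == 1
y = exp_scale_swish_reference(x, scale, speed=20.0)
assert torch.allclose(y, x * torch.sigmoid(x - 1.0))
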
@@ -383,12 +379,11 @@ class BasicNorm(torch.nn.Module):
          interprted as an offset from the input's ndim if negative.
          shis is NOT the num_channels; it should typically be one of
          {-2, -1, 0, 1, 2, 3}.
-      initial_eps_scale: a constant that determines the initial
-         "epsilon" that we add as ballast in:
-            scale = output_scale * ((input_vec**2).sum() + epsilon)**-0.5
-         Note: our epsilon is actually large, not small, but we keep the name
-         to indicate the connection with normal LayerNorm.  We set
-         epsilon initially to num_channels * initial_eps_scale.
+      initial_eps: the initial "epsilon" that we add as ballast in:
+            scale = ((input_vec**2).mean() + epsilon)**-0.5
+         Note: our epsilon is actually large, but we keep the name
+         to indicate the connection with normal LayerNorm.
       speed: a scaling factor that can be interpreted as scaling the learning
          rate for this module.  CAUTION: the default value of 10.0 intended to be
          used with Adam or amsgrad-type optimizers, e.g. Adam or Noam.
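
A small numeric check of the formula in the updated docstring (scale = ((input_vec**2).mean() + epsilon)**-0.5), assuming the simplified BasicNorm defined in the next hunk: if every frame of the input has mean square 1.0 over the channel dimension and eps = 0.25, each frame gets multiplied by (1 + 0.25) ** -0.5, roughly 0.894.

import torch

eps = 0.25
x = torch.randn(10, 512)
x = x / (x ** 2).mean(dim=-1, keepdim=True).sqrt()   # force mean(x**2) == 1 per frame

scales = ((x ** 2).mean(dim=-1, keepdim=True) + eps) ** -0.5
y = x * scales

# every frame is shrunk by the same factor, 1.25 ** -0.5 ~= 0.894
assert torch.allclose(scales, torch.full_like(scales, 1.25 ** -0.5), atol=1e-5)
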
@@ -398,42 +393,101 @@ class BasicNorm(torch.nn.Module):
     def __init__(self,
                  num_channels: int,
                  channel_dim: int = -1,  # CAUTION: see documentation.
-                 initial_eps_scale: float = 0.25,
-                 speed: float = 10.0):
+                 eps: float = 0.25):
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
         self.channel_dim = channel_dim
-        self.speed = speed
-        eps = num_channels * initial_eps_scale
-        # log_eps = log(eps) / speed
-        log_eps = torch.tensor(eps).log() / speed
-        self.log_eps = nn.Parameter(log_eps.detach())
-        # initial output-scale, to get LayerNorm-like behavior, is
-        # sqrt(num_channels).
-        initial_scale = torch.tensor(num_channels ** 0.5).log() / speed
-        self.log_scale = nn.Parameter(initial_scale.detach())
-
-    def _inner(self, x: Tensor) -> Tensor:
-        # inner product on last dim of x, keeping the dimension,
-        # i.e. torch.sum(x**2, dim=-1, keepdim=True), but more
-        # efficient.
-        if hasattr(torch, 'inner'):
-            return torch.inner(x).unsqueeze(-1)
-        else:
-            # TODO: we can do this with matrix multiplication, maybe.a
-            return torch.sum(x**2, dim=-1, keepdim=True)
+        self.eps = eps
 
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
-        x = x.transpose(-1, self.channel_dim)
-        eps = (self.log_eps * self.speed).exp()
-        out_scale = (self.log_scale * self.speed).exp()
-        scales = out_scale * (self._inner(x) + eps) ** -0.5
-        x = x * scales
-        x = x.transpose(-1, self.channel_dim)
-        return x
+        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps) ** -0.5
+        return x * scales
+
+
+class ScaledLinear(nn.Linear):
+    def __init__(self, *args, scale_speed=5.0, **kwargs):
+        super(ScaledLinear, self).__init__(*args, **kwargs)
+        self.weight_scale = nn.Parameter(torch.zeros(()))
+        self.scale_speed = scale_speed
+        if self.bias is not None:
+            self.bias_scale = nn.Parameter(torch.zeros(()))
+        else:
+            self.register_parameter('bias_scale', None)
+
+    def get_weight(self):
+        return self.weight * (self.weight_scale * self.scale_speed).exp()
+
+    def get_bias(self):
+        return (None if self.bias is None else
+                self.bias * (self.bias_scale * self.scale_speed).exp())
+
+    def forward(self, input: Tensor) -> Tensor:
+        return torch.nn.functional.linear(input, self.get_weight(),
+                                          self.get_bias())
+
+
+class ScaledConv1d(nn.Conv1d):
+    def __init__(self, *args, scale_speed = 5.0, **kwargs):
+        super(ScaledConv1d, self).__init__(*args, **kwargs)
+        self.scale_speed = scale_speed
+        self.weight_scale = nn.Parameter(torch.zeros(()))
+        if self.bias is not None:
+            self.bias_scale = nn.Parameter(torch.zeros(()))
+        else:
+            self.register_parameter('bias_scale', None)
+
+    def get_weight(self):
+        return self.weight * (self.weight_scale * self.scale_speed).exp()
+
+    def get_bias(self):
+        return (None if self.bias is None else
+                self.bias * (self.bias_scale * self.scale_speed).exp())
+
+    def forward(self, input: Tensor) -> Tensor:
+        F = torch.nn.functional
+        if self.padding_mode != 'zeros':
+            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
+                            self.get_weight(), self.get_bias(), self.stride,
+                            _single(0), self.dilation, self.groups)
+        return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
+                        self.padding, self.dilation, self.groups)
+
+
+class ScaledConv2d(nn.Conv2d):
+    def __init__(self, *args, scale_speed=5.0, **kwargs):
+        super(ScaledConv2d, self).__init__(*args, **kwargs)
+        self.scale_speed = scale_speed
+        self.weight_scale = nn.Parameter(torch.zeros(()))
+        if self.bias is not None:
+            self.bias_scale = nn.Parameter(torch.zeros(()))
+        else:
+            self.register_parameter('bias_scale', None)
+
+    def get_weight(self):
+        return self.weight * (self.weight_scale * self.scale_speed).exp()
+
+    def get_bias(self):
+        return (None if self.bias is None else
+                self.bias * (self.bias_scale * self.scale_speed).exp())
+
+    def _conv_forward(self, input, weight):
+        F = torch.nn.functional
+        if self.padding_mode != 'zeros':
+            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
+                            weight, self.get_bias(), self.stride,
+                            _pair(0), self.dilation, self.groups)
+        return F.conv2d(input, weight, self.bias, self.stride,
+                        self.padding, self.dilation, self.groups)
+
+    def forward(self, input: Tensor) -> Tensor:
+        return self._conv_forward(input, self.get_weight())
 
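
A quick sanity check of the drop-in property of the classes added above: with weight_scale and bias_scale initialized to zero the exponential factors are exactly 1, so a Scaled* layer reproduces the corresponding plain module while adding only two scalar parameters. The snippet below emulates ScaledConv1d.get_weight()/get_bias() at initialization against a plain nn.Conv1d; the layer sizes are arbitrary examples, not values from the recipe.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv1d(80, 80, kernel_size=3, padding=1)

scale_speed = 5.0
weight_scale = torch.zeros(())   # what ScaledConv1d registers as a Parameter
bias_scale = torch.zeros(())
w = conv.weight * (weight_scale * scale_speed).exp()
b = conv.bias * (bias_scale * scale_speed).exp()

x = torch.randn(4, 80, 100)
y_scaled = F.conv1d(x, w, b, stride=1, padding=1)
assert torch.allclose(y_scaled, conv(x))
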
@@ -576,6 +630,8 @@ def _test_basic_norm():
 
 
 
+
+
 if __name__ == '__main__':
     _test_deriv_balancer_sign()
     _test_deriv_balancer_magnitude()
@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
-from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm
+from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
 
 import torch
 from torch import Tensor, nn
@@ -157,30 +157,25 @@ class ConformerEncoderLayer(nn.Module):
         )
 
         self.feed_forward = nn.Sequential(
-            nn.Linear(d_model, dim_feedforward),
+            ScaledLinear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1, threshold=0.05,
                           max_factor=0.01),
-            SwishExpScale(dim_feedforward, speed=20.0, in_scale=2.0),
+            SwishExpScale(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
-            nn.Linear(dim_feedforward, d_model),
+            ScaledLinear(dim_feedforward, d_model),
         )
 
         self.feed_forward_macaron = nn.Sequential(
-            nn.Linear(d_model, dim_feedforward),
+            ScaledLinear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1, threshold=0.05,
                           max_factor=0.01),
-            SwishExpScale(dim_feedforward, speed=20.0, in_scale=2.0),
+            SwishExpScale(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
-            nn.Linear(dim_feedforward, d_model),
+            ScaledLinear(dim_feedforward, d_model),
         )
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
-        self.scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2)
-        self.post_scale_mha = ExpScale(1, speed=10.0, initial_scale=1.0)
-        self.scale_conv = ExpScale(1, speed=10.0, initial_scale=0.5)
-        self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5)
-        self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5)
 
         self.pre_norm_final = Identity()
         self.norm_final = BasicNorm(d_model)
@@ -216,13 +211,10 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
 
 
-        src = src + self.dropout(self.feed_forward_macaron(
-            self.scale_ff_macaron(src)))
+        src = src + self.dropout(self.feed_forward_macaron(src))
 
 
         # multi-headed self-attention module
-        residual = src
-        src = self.scale_mha(src)
         src_att = self.self_attn(
             src,
             src,
@@ -231,13 +223,13 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )[0]
-        src = residual + self.post_scale_mha(self.dropout(src_att))
+        src = src + self.dropout(src_att)
 
         # convolution module
-        src = src + self.dropout(self.conv_module(self.scale_conv(src)))
+        src = src + self.dropout(self.conv_module(src))
 
         # feed forward module
-        src = src + self.dropout(self.feed_forward(self.scale_ff(src)))
+        src = src + self.dropout(self.feed_forward(src))
 
         src = self.norm_final(self.pre_norm_final(src))
 
@@ -420,6 +412,7 @@ class RelPositionMultiheadAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         dropout: float = 0.0,
+        scale_speed: float = 5.0
     ) -> None:
         super(RelPositionMultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
@@ -430,18 +423,27 @@ class RelPositionMultiheadAttention(nn.Module):
             self.head_dim * num_heads == self.embed_dim
         ), "embed_dim must be divisible by num_heads"
 
-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
-        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+        self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
+        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True)
 
         # linear transformation for positional encoding.
-        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
         # these two learnable bias are used in matrix c and matrix d
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
         self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+        self.scale_speed = scale_speed
+        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
+        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
 
         self._reset_parameters()
 
+    def _pos_bias_u(self):
+        return self.pos_bias_u * (self.pos_bias_u_scale * self.scale_speed).exp()
+
+    def _pos_bias_v(self):
+        return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp()
+
     def _reset_parameters(self) -> None:
         nn.init.xavier_uniform_(self.in_proj.weight)
         nn.init.constant_(self.in_proj.bias, 0.0)
@@ -508,11 +510,11 @@ class RelPositionMultiheadAttention(nn.Module):
             pos_emb,
             self.embed_dim,
             self.num_heads,
-            self.in_proj.weight,
-            self.in_proj.bias,
+            self.in_proj.get_weight(),
+            self.in_proj.get_bias(),
             self.dropout,
-            self.out_proj.weight,
-            self.out_proj.bias,
+            self.out_proj.get_weight(),
+            self.out_proj.get_bias(),
             training=self.training,
             key_padding_mask=key_padding_mask,
             need_weights=need_weights,
@@ -743,11 +745,11 @@ class RelPositionMultiheadAttention(nn.Module):
         p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
         p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
 
-        q_with_bias_u = (q + self.pos_bias_u).transpose(
+        q_with_bias_u = (q + self._pos_bias_u()).transpose(
             1, 2
         )  # (batch, head, time1, d_k)
 
-        q_with_bias_v = (q + self.pos_bias_v).transpose(
+        q_with_bias_v = (q + self._pos_bias_v()).transpose(
             1, 2
         )  # (batch, head, time1, d_k)
 
@@ -842,7 +844,7 @@ class ConvolutionModule(nn.Module):
         # kernerl_size should be a odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0
 
-        self.pointwise_conv1 = nn.Conv1d(
+        self.pointwise_conv1 = ScaledConv1d(
             channels,
             2 * channels,
             kernel_size=1,
@@ -850,7 +852,7 @@ class ConvolutionModule(nn.Module):
             padding=0,
             bias=bias,
         )
-        self.depthwise_conv = nn.Conv1d(
+        self.depthwise_conv = ScaledConv1d(
             channels,
             channels,
             kernel_size,
@@ -860,12 +862,10 @@ class ConvolutionModule(nn.Module):
             bias=bias,
         )
 
-        self.scale = ExpScale(1, speed=10.0, initial_scale=1.0)
-
         # shape: (channels, 1), broadcasts with (batch, channel, time).
         self.activation = SwishOffset()
 
-        self.pointwise_conv2 = nn.Conv1d(
+        self.pointwise_conv2 = ScaledConv1d(
             channels,
             channels,
             kernel_size=1,
@@ -897,11 +897,6 @@ class ConvolutionModule(nn.Module):
         # TODO: can have a learned scale in here, or a fixed one.
         x = self.activation(x)
 
-        # x is (batch, channels, time)
-        x = x.permute(0, 2, 1)
-        x = self.scale(x)
-        x = x.permute(0, 2, 1)
-
         x = self.pointwise_conv2(x)  # (batch, channel, time)
 
         return x.permute(2, 0, 1)
@@ -982,7 +977,7 @@ class RandomCombine(torch.nn.Module):
         assert pure_prob >= 0 and pure_prob <= 1
         assert final_weight > 0 and final_weight < 1
         assert num_inputs >= 1
-        self.linear = nn.ModuleList([nn.Linear(num_channels, num_channels, bias=True)
+        self.linear = nn.ModuleList([ScaledLinear(num_channels, num_channels, bias=True)
                                      for _ in range(num_inputs - 1)])
 
         self.num_inputs = num_inputs