Update the modified attention codes

This commit is contained in:
Mingshuang Luo 2021-12-13 15:15:15 +08:00
parent e442369987
commit 4392da7235
2 changed files with 222 additions and 55 deletions

View File

@ -17,11 +17,11 @@
import math
import warnings
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from conv1d_abs_attention import Conv1dAbs
from transformer import Supervisions, Transformer, encoder_padding_mask
@ -178,8 +178,8 @@ class ConformerEncoderLayer(nn.Module):
d_model
) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
#define layernorm for conv1d_abs
# define layernorm for conv1d_abs
self.norm_conv_abs = nn.LayerNorm(d_model)
self.ff_scale = 0.5
@ -194,7 +194,8 @@ class ConformerEncoderLayer(nn.Module):
self.normalize_before = normalize_before
self.linear1 = nn.Linear(512, 1024)
self.conv1d_abs = ConvolutionModule_abs(1024, 64, kernel_size=21, padding=10)
self.conv1d_abs = Conv1dAbs(1024, 64, kernel_size=21, padding=10)
self.activation = nn.ReLU()
self.linear2 = nn.Linear(64, 512)
def forward(
@ -231,13 +232,15 @@ class ConformerEncoderLayer(nn.Module):
if not self.normalize_before:
src = self.norm_ff_macaron(src)
# modified-attention module
# modified-attention module
residual = src
if self.normalize_before:
src = self.norm_conv_abs(src)
src = self.linear1(src)
src = torch.exp(src.clamp(max=75))
src = src.permute(1, 2, 0)
src = self.conv1d_abs(src)
src = self.activation(src).permute(2, 0, 1)
src = torch.log(src)
src = self.linear2(src)
src = residual + self.dropout(src)
@ -396,7 +399,6 @@ class RelPositionalEncoding(torch.nn.Module):
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
@ -409,9 +411,11 @@ class RelPositionalEncoding(torch.nn.Module):
]
return self.dropout(x), self.dropout(pos_emb)
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args:
channels (int): The number of channels of conv layers.
@ -482,54 +486,6 @@ class ConvolutionModule(nn.Module):
return x.permute(2, 0, 1)
class ConvolutionModule_abs(nn.Module):
"""ConvolutionModule in Conformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernerl size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
"""
def __init__(
self, channels: int, out_channels: int, kernel_size: int, padding: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule_abs, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.conv1 = nn.Conv1d_abs(
channels,
out_channels,
kernel_size=kernel_size,
stride=1,
padding=padding,
bias=bias,
)
self.activation = nn.ReLU()
def forward(self, x: Tensor) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
Returns:
Tensor: Output tensor (#time, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
x = self.conv1(x)
x = self.activation(x)
return x.permute(2, 0, 1)
class Swish(torch.nn.Module):
"""Construct an Swish object."""

View File

@ -0,0 +1,211 @@
# coding=utf-8
import math
import collections
from itertools import repeat
import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from torch.nn import init
from torch.nn.modules.module import Module
from torch.nn.common_types import _size_1_t
from typing import Optional, Tuple
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return tuple(x)
return tuple(repeat(x, n))
return parse
def _reverse_repeat_tuple(t, n):
r"""Reverse the order of `t` and repeat each element for `n` times.
This can be used to translate padding arg used by Conv and Pooling modules
to the ones used by `F.pad`.
"""
return tuple(x for x in reversed(t) for _ in range(n))
_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)
class _ConvNd(Module):
__constants__ = [
"stride",
"padding",
"dilation",
"groups",
"padding_mode",
"output_padding",
"in_channels",
"out_channels",
"kernel_size",
]
__annotations__ = {"bias": Optional[torch.Tensor]}
_in_channels: int
out_channels: int
kernel_size: Tuple[int, ...]
stride: Tuple[int, ...]
padding: Tuple[int, ...]
dilation: Tuple[int, ...]
transposed: bool
output_padding: Tuple[int, ...]
groups: int
padding_mode: str
weight: Tensor
bias: Optional[Tensor]
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t,
padding: _size_1_t,
dilation: _size_1_t,
transposed: bool,
output_padding: _size_1_t,
groups: int,
bias: Optional[Tensor],
padding_mode: str,
) -> None:
super(_ConvNd, self).__init__()
if in_channels % groups != 0:
raise ValueError("in_channels must be divisible by groups")
if out_channels % groups != 0:
raise ValueError("out_channels must be divisible by groups")
valid_padding_modes = {"zeros", "reflect", "replicate", "circular"}
if padding_mode not in valid_padding_modes:
raise ValueError(
"padding_mode must be one of {}, but got padding_mode='{}'".format(
valid_padding_modes, padding_mode
)
)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.transposed = transposed
self.output_padding = output_padding
self.groups = groups
self.padding_mode = padding_mode
# `_reversed_padding_repeated_twice` is the padding to be passed to
# `F.pad` if needed (e.g., for non-zero padding types that are
# implemented as two ops: padding + conv). `F.pad` accepts paddings in
# reverse order than the dimension.
self._reversed_padding_repeated_twice = _reverse_repeat_tuple(
self.padding, 2
)
if transposed:
self.weight = Parameter(
torch.Tensor(in_channels, out_channels // groups, *kernel_size)
)
else:
self.weight = Parameter(
torch.Tensor(out_channels, in_channels // groups, *kernel_size)
)
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self) -> None:
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def extra_repr(self):
s = (
"{in_channels}, {out_channels}, kernel_size={kernel_size}"
", stride={stride}"
)
if self.padding != (0,) * len(self.padding):
s += ", padding={padding}"
if self.dilation != (1,) * len(self.dilation):
s += ", dilation={dilation}"
if self.output_padding != (0,) * len(self.output_padding):
s += ", output_padding={output_padding}"
if self.groups != 1:
s += ", groups={groups}"
if self.bias is None:
s += ", bias=False"
if self.padding_mode != "zeros":
s += ", padding_mode={padding_mode}"
return s.format(**self.__dict__)
def __setstate__(self, state):
super(_ConvNd, self).__setstate__(state)
if not hasattr(self, "padding_mode"):
self.padding_mode = "zeros"
class Conv1dAbs(_ConvNd):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: _size_1_t = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
):
kernel_size = _single(kernel_size)
stride = _single(stride)
padding = _single(padding)
dilation = _single(dilation)
super(Conv1dAbs, self).__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
False,
_single(0),
groups,
bias,
padding_mode,
)
def forward(self, input: Tensor) -> Tensor:
if self.padding_mode != "zeros":
return F.conv1d(
F.pad(
input,
self._reversed_padding_repeated_twice,
mode=self.padding_mode,
),
torch.abs(self.weight),
torch.abs(self.bias),
self.stride,
_single(0),
self.dilation,
self.groups,
)
return F.conv1d(
input,
torch.abs(self.weight),
torch.abs(self.bias),
self.stride,
self.padding,
self.dilation,
self.groups,
)