Update the modified attention codes

This commit is contained in:
Mingshuang Luo 2021-12-13 15:15:15 +08:00
parent e442369987
commit 4392da7235
2 changed files with 222 additions and 55 deletions

View File

@ -17,11 +17,11 @@
import math import math
import warnings
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
from conv1d_abs_attention import Conv1dAbs
from transformer import Supervisions, Transformer, encoder_padding_mask from transformer import Supervisions, Transformer, encoder_padding_mask
@ -194,7 +194,8 @@ class ConformerEncoderLayer(nn.Module):
self.normalize_before = normalize_before self.normalize_before = normalize_before
self.linear1 = nn.Linear(512, 1024) self.linear1 = nn.Linear(512, 1024)
self.conv1d_abs = ConvolutionModule_abs(1024, 64, kernel_size=21, padding=10) self.conv1d_abs = Conv1dAbs(1024, 64, kernel_size=21, padding=10)
self.activation = nn.ReLU()
self.linear2 = nn.Linear(64, 512) self.linear2 = nn.Linear(64, 512)
def forward( def forward(
@ -237,7 +238,9 @@ class ConformerEncoderLayer(nn.Module):
src = self.norm_conv_abs(src) src = self.norm_conv_abs(src)
src = self.linear1(src) src = self.linear1(src)
src = torch.exp(src.clamp(max=75)) src = torch.exp(src.clamp(max=75))
src = src.permute(1, 2, 0)
src = self.conv1d_abs(src) src = self.conv1d_abs(src)
src = self.activation(src).permute(2, 0, 1)
src = torch.log(src) src = torch.log(src)
src = self.linear2(src) src = self.linear2(src)
src = residual + self.dropout(src) src = residual + self.dropout(src)
@ -396,7 +399,6 @@ class RelPositionalEncoding(torch.nn.Module):
Returns: Returns:
torch.Tensor: Encoded tensor (batch, time, `*`). torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
""" """
self.extend_pe(x) self.extend_pe(x)
x = x * self.xscale x = x * self.xscale
@ -409,9 +411,11 @@ class RelPositionalEncoding(torch.nn.Module):
] ]
return self.dropout(x), self.dropout(pos_emb) return self.dropout(x), self.dropout(pos_emb)
class ConvolutionModule(nn.Module): class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model. """ConvolutionModule in Conformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args: Args:
channels (int): The number of channels of conv layers. channels (int): The number of channels of conv layers.
@ -482,54 +486,6 @@ class ConvolutionModule(nn.Module):
return x.permute(2, 0, 1) return x.permute(2, 0, 1)
class ConvolutionModule_abs(nn.Module):
    """Single 1-D convolution followed by ReLU, applied along the time axis.

    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py

    Args:
        channels (int): The number of input channels of the conv layer.
        out_channels (int): The number of output channels of the conv layer.
        kernel_size (int): Kernel size of the conv layer; must be odd.
        padding (int): Padding forwarded unchanged to the conv layer.
        bias (bool): Whether to use bias in the conv layer (default=True).

    """

    def __init__(
        self, channels: int, out_channels: int, kernel_size: int, padding: int, bias: bool = True
    ) -> None:
        """Construct a ConvolutionModule_abs object."""
        super(ConvolutionModule_abs, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        # NOTE(review): `Conv1d_abs` is looked up on `nn`; it is not defined in
        # this file -- presumably supplied by a patched torch build. Confirm.
        self.conv1 = nn.Conv1d_abs(
            channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            bias=bias,
        )

        self.activation = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """Compute convolution module.

        Args:
            x: Input tensor (#time, batch, channels).

        Returns:
            Tensor: Output tensor (#time, batch, out_channels).
                NOTE(review): the time length is presumably preserved only
                when padding == (kernel_size - 1) // 2 -- confirm with caller.

        """
        # exchange the temporal dimension and the feature dimension
        x = x.permute(1, 2, 0)  # (#batch, channels, time).

        x = self.conv1(x)
        x = self.activation(x)

        # move the time axis back to the front
        return x.permute(2, 0, 1)
class Swish(torch.nn.Module): class Swish(torch.nn.Module):
"""Construct an Swish object.""" """Construct an Swish object."""

View File

@ -0,0 +1,211 @@
# coding=utf-8
import math
import collections
from itertools import repeat
import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from torch.nn import init
from torch.nn.modules.module import Module
from torch.nn.common_types import _size_1_t
from typing import Optional, Tuple
def _ntuple(n):
    """Return a function that normalizes its argument to a tuple.

    The returned function converts an iterable argument to a tuple as-is
    (its length is not checked against ``n``) and expands any non-iterable
    argument into a tuple holding ``n`` copies of it.
    """

    def parse(x):
        is_iterable = isinstance(x, collections.abc.Iterable)
        # Iterables pass through element-wise; scalars are broadcast n times.
        return tuple(x) if is_iterable else (x,) * n

    return parse
def _reverse_repeat_tuple(t, n):
    r"""Walk `t` back to front, emitting every element `n` times in a row.

    Useful for converting the per-dimension padding argument of Conv and
    Pooling modules into the flat, last-dimension-first layout that `F.pad`
    takes.
    """
    repeated = []
    for item in reversed(t):
        repeated.extend([item] * n)
    return tuple(repeated)
# Normalizers that turn a scalar or an iterable into a 1-, 2-, 3- or 4-tuple.
_single, _pair, _triple, _quadruple = map(_ntuple, (1, 2, 3, 4))
class _ConvNd(Module):
    """Shared configuration and parameter setup for N-dimensional convolutions.

    The constructor validates the channel/group/padding-mode combination,
    stores every hyper-parameter as an attribute, and allocates the learnable
    ``weight`` (plus an optional ``bias``). This class defines no ``forward``;
    subclasses are expected to provide it.

    Args:
        in_channels: Number of input channels; must be divisible by ``groups``.
        out_channels: Number of output channels; must be divisible by
            ``groups``.
        kernel_size: Per-dimension kernel size. It is unpacked into the weight
            shape, so it must already be a tuple.
        stride: Per-dimension stride tuple (stored unchanged).
        padding: Per-dimension padding tuple (stored unchanged).
        dilation: Per-dimension dilation tuple (stored unchanged).
        transposed: Selects the weight layout:
            ``(in_channels, out_channels // groups, *kernel_size)`` when true,
            ``(out_channels, in_channels // groups, *kernel_size)`` otherwise.
        output_padding: Per-dimension output padding tuple (stored unchanged).
        groups: Number of channel groups; must be a positive integer.
        bias: If truthy, a learnable bias of shape ``(out_channels,)`` is
            created; otherwise ``bias`` is registered as ``None``.
        padding_mode: One of ``"zeros"``, ``"reflect"``, ``"replicate"`` or
            ``"circular"``.

    Raises:
        ValueError: If ``groups`` is not positive, if either channel count is
            not divisible by ``groups``, or if ``padding_mode`` is unknown.
    """

    __constants__ = [
        "stride",
        "padding",
        "dilation",
        "groups",
        "padding_mode",
        "output_padding",
        "in_channels",
        "out_channels",
        "kernel_size",
    ]
    __annotations__ = {"bias": Optional[torch.Tensor]}

    in_channels: int
    out_channels: int
    kernel_size: Tuple[int, ...]
    stride: Tuple[int, ...]
    padding: Tuple[int, ...]
    dilation: Tuple[int, ...]
    transposed: bool
    output_padding: Tuple[int, ...]
    groups: int
    padding_mode: str
    weight: Tensor
    bias: Optional[Tensor]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t,
        padding: _size_1_t,
        dilation: _size_1_t,
        transposed: bool,
        output_padding: _size_1_t,
        groups: int,
        bias: bool,
        padding_mode: str,
    ) -> None:
        super().__init__()
        # Checked first so that groups == 0 yields a clear error instead of a
        # ZeroDivisionError from the modulo checks below.
        if groups <= 0:
            raise ValueError("groups must be a positive integer")
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        valid_padding_modes = {"zeros", "reflect", "replicate", "circular"}
        if padding_mode not in valid_padding_modes:
            raise ValueError(
                "padding_mode must be one of {}, but got padding_mode='{}'".format(
                    valid_padding_modes, padding_mode
                )
            )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        self.padding_mode = padding_mode
        # `_reversed_padding_repeated_twice` is the padding to be passed to
        # `F.pad` if needed (e.g., for non-zero padding types that are
        # implemented as two ops: padding + conv). `F.pad` accepts paddings in
        # reverse order than the dimension.
        self._reversed_padding_repeated_twice = _reverse_repeat_tuple(
            self.padding, 2
        )
        # torch.empty is the supported way to allocate uninitialized storage;
        # the values are overwritten by reset_parameters() below.
        if transposed:
            weight_shape = (in_channels, out_channels // groups, *kernel_size)
        else:
            weight_shape = (out_channels, in_channels // groups, *kernel_size)
        self.weight = Parameter(torch.empty(weight_shape))
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize ``weight`` (Kaiming uniform) and ``bias`` (uniform)."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            # A zero fan-in would make the bound infinite; leave the bias
            # untouched in that degenerate case.
            if fan_in != 0:
                bound = 1 / math.sqrt(fan_in)
                init.uniform_(self.bias, -bound, bound)

    def extra_repr(self):
        """Describe the layer, listing only options that differ from defaults."""
        s = (
            "{in_channels}, {out_channels}, kernel_size={kernel_size}"
            ", stride={stride}"
        )
        if self.padding != (0,) * len(self.padding):
            s += ", padding={padding}"
        if self.dilation != (1,) * len(self.dilation):
            s += ", dilation={dilation}"
        if self.output_padding != (0,) * len(self.output_padding):
            s += ", output_padding={output_padding}"
        if self.groups != 1:
            s += ", groups={groups}"
        if self.bias is None:
            s += ", bias=False"
        if self.padding_mode != "zeros":
            s += ", padding_mode={padding_mode}"
        return s.format(**self.__dict__)

    def __setstate__(self, state):
        """Restore pickled state, defaulting ``padding_mode`` when it is absent."""
        super().__setstate__(state)
        if not hasattr(self, "padding_mode"):
            self.padding_mode = "zeros"
class Conv1dAbs(_ConvNd):
    """1-D convolution that uses the absolute values of its weight and bias.

    The stored parameters are unconstrained; ``torch.abs`` is applied to them
    on every forward pass, so the effective kernel and bias are always
    non-negative.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size (int or 1-tuple).
        stride: Convolution stride (default 1).
        padding: Padding added to both ends of the input (default 0).
        dilation: Spacing between kernel elements (default 1).
        groups: Number of channel groups (default 1).
        bias: Whether to create a learnable bias (default True).
        padding_mode: ``"zeros"``, ``"reflect"``, ``"replicate"`` or
            ``"circular"`` (default ``"zeros"``).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
    ):
        super().__init__(
            in_channels,
            out_channels,
            _single(kernel_size),
            _single(stride),
            _single(padding),
            _single(dilation),
            False,  # not a transposed convolution
            _single(0),  # output_padding is unused for regular convolutions
            groups,
            bias,
            padding_mode,
        )

    def forward(self, input: Tensor) -> Tensor:
        """Convolve ``input`` with ``|weight|`` and add ``|bias|`` if present.

        Args:
            input: Tensor in the layout accepted by ``F.conv1d``.

        Returns:
            The result of ``F.conv1d`` with the absolute-valued parameters.
        """
        weight = torch.abs(self.weight)
        # The bias is None when the layer was built with bias=False;
        # torch.abs(None) would raise, so pass None straight to conv1d.
        bias = None if self.bias is None else torch.abs(self.bias)

        if self.padding_mode != "zeros":
            # Non-zero padding modes are applied explicitly with F.pad, after
            # which the convolution itself must not pad again.
            input = F.pad(
                input,
                self._reversed_padding_repeated_twice,
                mode=self.padding_mode,
            )
            padding = _single(0)
        else:
            padding = self.padding

        return F.conv1d(
            input,
            weight,
            bias,
            self.stride,
            padding,
            self.dilation,
            self.groups,
        )