mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Update the modified attention codes
This commit is contained in:
parent
e442369987
commit
4392da7235
@ -17,11 +17,11 @@
|
||||
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from conv1d_abs_attention import Conv1dAbs
|
||||
from transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
|
||||
|
||||
@ -178,8 +178,8 @@ class ConformerEncoderLayer(nn.Module):
|
||||
d_model
|
||||
) # for the macaron style FNN module
|
||||
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
|
||||
|
||||
#define layernorm for conv1d_abs
|
||||
|
||||
# define layernorm for conv1d_abs
|
||||
self.norm_conv_abs = nn.LayerNorm(d_model)
|
||||
|
||||
self.ff_scale = 0.5
|
||||
@ -194,7 +194,8 @@ class ConformerEncoderLayer(nn.Module):
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
self.linear1 = nn.Linear(512, 1024)
|
||||
self.conv1d_abs = ConvolutionModule_abs(1024, 64, kernel_size=21, padding=10)
|
||||
self.conv1d_abs = Conv1dAbs(1024, 64, kernel_size=21, padding=10)
|
||||
self.activation = nn.ReLU()
|
||||
self.linear2 = nn.Linear(64, 512)
|
||||
|
||||
def forward(
|
||||
@ -231,13 +232,15 @@ class ConformerEncoderLayer(nn.Module):
|
||||
if not self.normalize_before:
|
||||
src = self.norm_ff_macaron(src)
|
||||
|
||||
# modified-attention module
|
||||
# modified-attention module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_conv_abs(src)
|
||||
src = self.linear1(src)
|
||||
src = torch.exp(src.clamp(max=75))
|
||||
src = src.permute(1, 2, 0)
|
||||
src = self.conv1d_abs(src)
|
||||
src = self.activation(src).permute(2, 0, 1)
|
||||
src = torch.log(src)
|
||||
src = self.linear2(src)
|
||||
src = residual + self.dropout(src)
|
||||
@ -396,7 +399,6 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x * self.xscale
|
||||
@ -409,9 +411,11 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
]
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
||||
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
"""ConvolutionModule in Conformer model.
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
|
||||
Modified from
|
||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
|
||||
|
||||
Args:
|
||||
channels (int): The number of channels of conv layers.
|
||||
@ -482,54 +486,6 @@ class ConvolutionModule(nn.Module):
|
||||
return x.permute(2, 0, 1)
|
||||
|
||||
|
||||
class ConvolutionModule_abs(nn.Module):
|
||||
"""ConvolutionModule in Conformer model.
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
|
||||
|
||||
Args:
|
||||
channels (int): The number of channels of conv layers.
|
||||
kernel_size (int): Kernerl size of conv layers.
|
||||
bias (bool): Whether to use bias in conv layers (default=True).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, channels: int, out_channels: int, kernel_size: int, padding: int, bias: bool = True
|
||||
) -> None:
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule_abs, self).__init__()
|
||||
# kernerl_size should be a odd number for 'SAME' padding
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
|
||||
self.conv1 = nn.Conv1d_abs(
|
||||
channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=1,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.activation = nn.ReLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Compute convolution module.
|
||||
|
||||
Args:
|
||||
x: Input tensor (#time, batch, channels).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (#time, batch, channels).
|
||||
|
||||
"""
|
||||
# exchange the temporal dimension and the feature dimension
|
||||
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
||||
x = self.conv1(x)
|
||||
x = self.activation(x)
|
||||
|
||||
return x.permute(2, 0, 1)
|
||||
|
||||
|
||||
class Swish(torch.nn.Module):
|
||||
"""Construct an Swish object."""
|
||||
|
||||
|
211
egs/librispeech/ASR/conformer_ctc/conv1d_abs_attention.py
Normal file
211
egs/librispeech/ASR/conformer_ctc/conv1d_abs_attention.py
Normal file
@ -0,0 +1,211 @@
|
||||
# coding=utf-8
|
||||
import math
|
||||
import collections
|
||||
from itertools import repeat
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn import functional as F
|
||||
from torch.nn import init
|
||||
from torch.nn.modules.module import Module
|
||||
from torch.nn.common_types import _size_1_t
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return tuple(x)
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
def _reverse_repeat_tuple(t, n):
|
||||
r"""Reverse the order of `t` and repeat each element for `n` times.
|
||||
|
||||
This can be used to translate padding arg used by Conv and Pooling modules
|
||||
to the ones used by `F.pad`.
|
||||
"""
|
||||
return tuple(x for x in reversed(t) for _ in range(n))
|
||||
|
||||
|
||||
_single = _ntuple(1)
|
||||
_pair = _ntuple(2)
|
||||
_triple = _ntuple(3)
|
||||
_quadruple = _ntuple(4)
|
||||
|
||||
|
||||
class _ConvNd(Module):
|
||||
|
||||
__constants__ = [
|
||||
"stride",
|
||||
"padding",
|
||||
"dilation",
|
||||
"groups",
|
||||
"padding_mode",
|
||||
"output_padding",
|
||||
"in_channels",
|
||||
"out_channels",
|
||||
"kernel_size",
|
||||
]
|
||||
__annotations__ = {"bias": Optional[torch.Tensor]}
|
||||
|
||||
_in_channels: int
|
||||
out_channels: int
|
||||
kernel_size: Tuple[int, ...]
|
||||
stride: Tuple[int, ...]
|
||||
padding: Tuple[int, ...]
|
||||
dilation: Tuple[int, ...]
|
||||
transposed: bool
|
||||
output_padding: Tuple[int, ...]
|
||||
groups: int
|
||||
padding_mode: str
|
||||
weight: Tensor
|
||||
bias: Optional[Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: _size_1_t,
|
||||
stride: _size_1_t,
|
||||
padding: _size_1_t,
|
||||
dilation: _size_1_t,
|
||||
transposed: bool,
|
||||
output_padding: _size_1_t,
|
||||
groups: int,
|
||||
bias: Optional[Tensor],
|
||||
padding_mode: str,
|
||||
) -> None:
|
||||
super(_ConvNd, self).__init__()
|
||||
if in_channels % groups != 0:
|
||||
raise ValueError("in_channels must be divisible by groups")
|
||||
if out_channels % groups != 0:
|
||||
raise ValueError("out_channels must be divisible by groups")
|
||||
valid_padding_modes = {"zeros", "reflect", "replicate", "circular"}
|
||||
if padding_mode not in valid_padding_modes:
|
||||
raise ValueError(
|
||||
"padding_mode must be one of {}, but got padding_mode='{}'".format(
|
||||
valid_padding_modes, padding_mode
|
||||
)
|
||||
)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.transposed = transposed
|
||||
self.output_padding = output_padding
|
||||
self.groups = groups
|
||||
self.padding_mode = padding_mode
|
||||
# `_reversed_padding_repeated_twice` is the padding to be passed to
|
||||
# `F.pad` if needed (e.g., for non-zero padding types that are
|
||||
# implemented as two ops: padding + conv). `F.pad` accepts paddings in
|
||||
# reverse order than the dimension.
|
||||
self._reversed_padding_repeated_twice = _reverse_repeat_tuple(
|
||||
self.padding, 2
|
||||
)
|
||||
if transposed:
|
||||
self.weight = Parameter(
|
||||
torch.Tensor(in_channels, out_channels // groups, *kernel_size)
|
||||
)
|
||||
else:
|
||||
self.weight = Parameter(
|
||||
torch.Tensor(out_channels, in_channels // groups, *kernel_size)
|
||||
)
|
||||
if bias:
|
||||
self.bias = Parameter(torch.Tensor(out_channels))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
if self.bias is not None:
|
||||
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
init.uniform_(self.bias, -bound, bound)
|
||||
|
||||
def extra_repr(self):
|
||||
s = (
|
||||
"{in_channels}, {out_channels}, kernel_size={kernel_size}"
|
||||
", stride={stride}"
|
||||
)
|
||||
if self.padding != (0,) * len(self.padding):
|
||||
s += ", padding={padding}"
|
||||
if self.dilation != (1,) * len(self.dilation):
|
||||
s += ", dilation={dilation}"
|
||||
if self.output_padding != (0,) * len(self.output_padding):
|
||||
s += ", output_padding={output_padding}"
|
||||
if self.groups != 1:
|
||||
s += ", groups={groups}"
|
||||
if self.bias is None:
|
||||
s += ", bias=False"
|
||||
if self.padding_mode != "zeros":
|
||||
s += ", padding_mode={padding_mode}"
|
||||
return s.format(**self.__dict__)
|
||||
|
||||
def __setstate__(self, state):
|
||||
super(_ConvNd, self).__setstate__(state)
|
||||
if not hasattr(self, "padding_mode"):
|
||||
self.padding_mode = "zeros"
|
||||
|
||||
|
||||
class Conv1dAbs(_ConvNd):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: _size_1_t,
|
||||
stride: _size_1_t = 1,
|
||||
padding: _size_1_t = 0,
|
||||
dilation: _size_1_t = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = "zeros",
|
||||
):
|
||||
kernel_size = _single(kernel_size)
|
||||
stride = _single(stride)
|
||||
padding = _single(padding)
|
||||
dilation = _single(dilation)
|
||||
super(Conv1dAbs, self).__init__(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
False,
|
||||
_single(0),
|
||||
groups,
|
||||
bias,
|
||||
padding_mode,
|
||||
)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
if self.padding_mode != "zeros":
|
||||
return F.conv1d(
|
||||
F.pad(
|
||||
input,
|
||||
self._reversed_padding_repeated_twice,
|
||||
mode=self.padding_mode,
|
||||
),
|
||||
torch.abs(self.weight),
|
||||
torch.abs(self.bias),
|
||||
self.stride,
|
||||
_single(0),
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
return F.conv1d(
|
||||
input,
|
||||
torch.abs(self.weight),
|
||||
torch.abs(self.bias),
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user