mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
Update the modified attention codes
This commit is contained in:
parent
e442369987
commit
4392da7235
@ -17,11 +17,11 @@
|
|||||||
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import warnings
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from conv1d_abs_attention import Conv1dAbs
|
||||||
from transformer import Supervisions, Transformer, encoder_padding_mask
|
from transformer import Supervisions, Transformer, encoder_padding_mask
|
||||||
|
|
||||||
|
|
||||||
@ -179,7 +179,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
) # for the macaron style FNN module
|
) # for the macaron style FNN module
|
||||||
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
|
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
|
||||||
|
|
||||||
#define layernorm for conv1d_abs
|
# define layernorm for conv1d_abs
|
||||||
self.norm_conv_abs = nn.LayerNorm(d_model)
|
self.norm_conv_abs = nn.LayerNorm(d_model)
|
||||||
|
|
||||||
self.ff_scale = 0.5
|
self.ff_scale = 0.5
|
||||||
@ -194,7 +194,8 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
self.normalize_before = normalize_before
|
self.normalize_before = normalize_before
|
||||||
|
|
||||||
self.linear1 = nn.Linear(512, 1024)
|
self.linear1 = nn.Linear(512, 1024)
|
||||||
self.conv1d_abs = ConvolutionModule_abs(1024, 64, kernel_size=21, padding=10)
|
self.conv1d_abs = Conv1dAbs(1024, 64, kernel_size=21, padding=10)
|
||||||
|
self.activation = nn.ReLU()
|
||||||
self.linear2 = nn.Linear(64, 512)
|
self.linear2 = nn.Linear(64, 512)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -237,7 +238,9 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
src = self.norm_conv_abs(src)
|
src = self.norm_conv_abs(src)
|
||||||
src = self.linear1(src)
|
src = self.linear1(src)
|
||||||
src = torch.exp(src.clamp(max=75))
|
src = torch.exp(src.clamp(max=75))
|
||||||
|
src = src.permute(1, 2, 0)
|
||||||
src = self.conv1d_abs(src)
|
src = self.conv1d_abs(src)
|
||||||
|
src = self.activation(src).permute(2, 0, 1)
|
||||||
src = torch.log(src)
|
src = torch.log(src)
|
||||||
src = self.linear2(src)
|
src = self.linear2(src)
|
||||||
src = residual + self.dropout(src)
|
src = residual + self.dropout(src)
|
||||||
@ -396,7 +399,6 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.extend_pe(x)
|
self.extend_pe(x)
|
||||||
x = x * self.xscale
|
x = x * self.xscale
|
||||||
@ -409,9 +411,11 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
]
|
]
|
||||||
return self.dropout(x), self.dropout(pos_emb)
|
return self.dropout(x), self.dropout(pos_emb)
|
||||||
|
|
||||||
|
|
||||||
class ConvolutionModule(nn.Module):
|
class ConvolutionModule(nn.Module):
|
||||||
"""ConvolutionModule in Conformer model.
|
"""ConvolutionModule in Conformer model.
|
||||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
|
Modified from
|
||||||
|
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
channels (int): The number of channels of conv layers.
|
channels (int): The number of channels of conv layers.
|
||||||
@ -482,54 +486,6 @@ class ConvolutionModule(nn.Module):
|
|||||||
return x.permute(2, 0, 1)
|
return x.permute(2, 0, 1)
|
||||||
|
|
||||||
|
|
||||||
class ConvolutionModule_abs(nn.Module):
    """One 1-D convolution followed by ReLU, applied along the time axis.

    Args:
        channels (int): Number of input channels passed to the convolution.
        out_channels (int): Number of output channels passed to the
            convolution.
        kernel_size (int): Kernel size of the convolution; must be odd.
        padding (int): Padding forwarded unchanged to the convolution.
        bias (bool): Whether the convolution uses a bias (default=True).

    """

    def __init__(
        self,
        channels: int,
        out_channels: int,
        kernel_size: int,
        padding: int,
        bias: bool = True,
    ) -> None:
        """Construct a ConvolutionModule_abs object."""
        super().__init__()
        # Only odd kernel sizes are accepted (needed for 'SAME' padding).
        assert kernel_size % 2 == 1

        # NOTE(review): `Conv1d_abs` is resolved as an attribute of
        # `torch.nn`; stock PyTorch does not provide a layer with this name,
        # so this presumably relies on a patched build -- confirm.
        conv_options = dict(
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            bias=bias,
        )
        self.conv1 = nn.Conv1d_abs(channels, out_channels, **conv_options)

        self.activation = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """Run the convolution and the ReLU on `x`.

        Args:
            x: Input tensor (#time, batch, channels).

        Returns:
            Tensor: Output tensor with the time axis moved back to the front,
            i.e. (#time, batch, channels produced by the convolution).

        """
        # Move time to the last axis: (#batch, channels, time).
        channels_first = x.permute(1, 2, 0)
        activated = self.activation(self.conv1(channels_first))

        # Restore the (#time, batch, channels) layout.
        return activated.permute(2, 0, 1)
|
|
||||||
|
|
||||||
|
|
||||||
class Swish(torch.nn.Module):
|
class Swish(torch.nn.Module):
|
||||||
"""Construct an Swish object."""
|
"""Construct an Swish object."""
|
||||||
|
|
||||||
|
211
egs/librispeech/ASR/conformer_ctc/conv1d_abs_attention.py
Normal file
211
egs/librispeech/ASR/conformer_ctc/conv1d_abs_attention.py
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
import math
|
||||||
|
import collections
|
||||||
|
from itertools import repeat
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.nn import init
|
||||||
|
from torch.nn.modules.module import Module
|
||||||
|
from torch.nn.common_types import _size_1_t
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
def _ntuple(n):
|
||||||
|
def parse(x):
|
||||||
|
if isinstance(x, collections.abc.Iterable):
|
||||||
|
return tuple(x)
|
||||||
|
return tuple(repeat(x, n))
|
||||||
|
|
||||||
|
return parse
|
||||||
|
|
||||||
|
|
||||||
|
def _reverse_repeat_tuple(t, n):
|
||||||
|
r"""Reverse the order of `t` and repeat each element for `n` times.
|
||||||
|
|
||||||
|
This can be used to translate padding arg used by Conv and Pooling modules
|
||||||
|
to the ones used by `F.pad`.
|
||||||
|
"""
|
||||||
|
return tuple(x for x in reversed(t) for _ in range(n))
|
||||||
|
|
||||||
|
|
||||||
|
_single = _ntuple(1)
|
||||||
|
_pair = _ntuple(2)
|
||||||
|
_triple = _ntuple(3)
|
||||||
|
_quadruple = _ntuple(4)
|
||||||
|
|
||||||
|
|
||||||
|
class _ConvNd(Module):
|
||||||
|
|
||||||
|
__constants__ = [
|
||||||
|
"stride",
|
||||||
|
"padding",
|
||||||
|
"dilation",
|
||||||
|
"groups",
|
||||||
|
"padding_mode",
|
||||||
|
"output_padding",
|
||||||
|
"in_channels",
|
||||||
|
"out_channels",
|
||||||
|
"kernel_size",
|
||||||
|
]
|
||||||
|
__annotations__ = {"bias": Optional[torch.Tensor]}
|
||||||
|
|
||||||
|
_in_channels: int
|
||||||
|
out_channels: int
|
||||||
|
kernel_size: Tuple[int, ...]
|
||||||
|
stride: Tuple[int, ...]
|
||||||
|
padding: Tuple[int, ...]
|
||||||
|
dilation: Tuple[int, ...]
|
||||||
|
transposed: bool
|
||||||
|
output_padding: Tuple[int, ...]
|
||||||
|
groups: int
|
||||||
|
padding_mode: str
|
||||||
|
weight: Tensor
|
||||||
|
bias: Optional[Tensor]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: _size_1_t,
|
||||||
|
stride: _size_1_t,
|
||||||
|
padding: _size_1_t,
|
||||||
|
dilation: _size_1_t,
|
||||||
|
transposed: bool,
|
||||||
|
output_padding: _size_1_t,
|
||||||
|
groups: int,
|
||||||
|
bias: Optional[Tensor],
|
||||||
|
padding_mode: str,
|
||||||
|
) -> None:
|
||||||
|
super(_ConvNd, self).__init__()
|
||||||
|
if in_channels % groups != 0:
|
||||||
|
raise ValueError("in_channels must be divisible by groups")
|
||||||
|
if out_channels % groups != 0:
|
||||||
|
raise ValueError("out_channels must be divisible by groups")
|
||||||
|
valid_padding_modes = {"zeros", "reflect", "replicate", "circular"}
|
||||||
|
if padding_mode not in valid_padding_modes:
|
||||||
|
raise ValueError(
|
||||||
|
"padding_mode must be one of {}, but got padding_mode='{}'".format(
|
||||||
|
valid_padding_modes, padding_mode
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
self.dilation = dilation
|
||||||
|
self.transposed = transposed
|
||||||
|
self.output_padding = output_padding
|
||||||
|
self.groups = groups
|
||||||
|
self.padding_mode = padding_mode
|
||||||
|
# `_reversed_padding_repeated_twice` is the padding to be passed to
|
||||||
|
# `F.pad` if needed (e.g., for non-zero padding types that are
|
||||||
|
# implemented as two ops: padding + conv). `F.pad` accepts paddings in
|
||||||
|
# reverse order than the dimension.
|
||||||
|
self._reversed_padding_repeated_twice = _reverse_repeat_tuple(
|
||||||
|
self.padding, 2
|
||||||
|
)
|
||||||
|
if transposed:
|
||||||
|
self.weight = Parameter(
|
||||||
|
torch.Tensor(in_channels, out_channels // groups, *kernel_size)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.weight = Parameter(
|
||||||
|
torch.Tensor(out_channels, in_channels // groups, *kernel_size)
|
||||||
|
)
|
||||||
|
if bias:
|
||||||
|
self.bias = Parameter(torch.Tensor(out_channels))
|
||||||
|
else:
|
||||||
|
self.register_parameter("bias", None)
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self) -> None:
|
||||||
|
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||||
|
if self.bias is not None:
|
||||||
|
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
|
||||||
|
bound = 1 / math.sqrt(fan_in)
|
||||||
|
init.uniform_(self.bias, -bound, bound)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
s = (
|
||||||
|
"{in_channels}, {out_channels}, kernel_size={kernel_size}"
|
||||||
|
", stride={stride}"
|
||||||
|
)
|
||||||
|
if self.padding != (0,) * len(self.padding):
|
||||||
|
s += ", padding={padding}"
|
||||||
|
if self.dilation != (1,) * len(self.dilation):
|
||||||
|
s += ", dilation={dilation}"
|
||||||
|
if self.output_padding != (0,) * len(self.output_padding):
|
||||||
|
s += ", output_padding={output_padding}"
|
||||||
|
if self.groups != 1:
|
||||||
|
s += ", groups={groups}"
|
||||||
|
if self.bias is None:
|
||||||
|
s += ", bias=False"
|
||||||
|
if self.padding_mode != "zeros":
|
||||||
|
s += ", padding_mode={padding_mode}"
|
||||||
|
return s.format(**self.__dict__)
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
super(_ConvNd, self).__setstate__(state)
|
||||||
|
if not hasattr(self, "padding_mode"):
|
||||||
|
self.padding_mode = "zeros"
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1dAbs(_ConvNd):
    """1-D convolution that applies the absolute values of its parameters.

    The learnable `weight` and `bias` are stored unconstrained; `forward`
    passes `abs(weight)` and `abs(bias)` to `F.conv1d`, so the effective
    kernel and bias are never negative.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size (int or 1-tuple).
        stride: Stride of the convolution (default 1).
        padding: Padding added to both sides of the input (default 0).
        dilation: Spacing between kernel elements (default 1).
        groups: Number of groups (default 1).
        bias: Whether to allocate a learnable bias (default True).
        padding_mode: "zeros", "reflect", "replicate" or "circular"
            (default "zeros").
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
    ):
        super().__init__(
            in_channels,
            out_channels,
            _single(kernel_size),
            _single(stride),
            _single(padding),
            _single(dilation),
            False,  # transposed
            _single(0),  # output_padding
            groups,
            bias,
            padding_mode,
        )

    def forward(self, input: Tensor) -> Tensor:
        """Convolve `input` with the absolute-valued kernel and bias.

        Args:
            input: Tensor accepted by `F.conv1d`.

        Returns:
            The result of `F.conv1d` using `abs(weight)` and, when the layer
            has a bias, `abs(bias)`.
        """
        weight = torch.abs(self.weight)
        # `self.bias` is None when the layer was built with bias=False;
        # `torch.abs(None)` would raise, so pass None through to `F.conv1d`.
        bias = None if self.bias is None else torch.abs(self.bias)

        if self.padding_mode != "zeros":
            # Non-zero padding modes are done as an explicit pad followed by
            # an unpadded convolution.
            padded = F.pad(
                input,
                self._reversed_padding_repeated_twice,
                mode=self.padding_mode,
            )
            return F.conv1d(
                padded,
                weight,
                bias,
                self.stride,
                _single(0),
                self.dilation,
                self.groups,
            )
        return F.conv1d(
            input,
            weight,
            bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
|
Loading…
x
Reference in New Issue
Block a user