From 4392da723543d1008e06a2d7bfacc4a648e6e7ac Mon Sep 17 00:00:00 2001
From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com>
Date: Mon, 13 Dec 2021 15:15:15 +0800
Subject: [PATCH] Update the modified attention codes

---
 .../ASR/conformer_ctc/conformer.py            |  66 +-----
 .../ASR/conformer_ctc/conv1d_abs_attention.py | 211 ++++++++++++++++++
 2 files changed, 222 insertions(+), 55 deletions(-)
 create mode 100644 egs/librispeech/ASR/conformer_ctc/conv1d_abs_attention.py

diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py
index d8b02cc05..91f8cf694 100644
--- a/egs/librispeech/ASR/conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/conformer.py
@@ -17,11 +17,11 @@
 
 import math
-import warnings
 from typing import Optional, Tuple
 
 import torch
 from torch import Tensor, nn
+from conv1d_abs_attention import Conv1dAbs
 from transformer import Supervisions, Transformer, encoder_padding_mask
 
 
@@ -178,8 +178,8 @@ class ConformerEncoderLayer(nn.Module):
             d_model
         )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
-        
-        #define layernorm for conv1d_abs
+
+        # define layernorm for conv1d_abs
         self.norm_conv_abs = nn.LayerNorm(d_model)
 
         self.ff_scale = 0.5
@@ -194,7 +194,8 @@ class ConformerEncoderLayer(nn.Module):
         self.normalize_before = normalize_before
 
         self.linear1 = nn.Linear(512, 1024)
-        self.conv1d_abs = ConvolutionModule_abs(1024, 64, kernel_size=21, padding=10)
+        self.conv1d_abs = Conv1dAbs(1024, 64, kernel_size=21, padding=10)
+        self.activation = nn.ReLU()
         self.linear2 = nn.Linear(64, 512)
 
     def forward(
@@ -231,13 +232,15 @@ class ConformerEncoderLayer(nn.Module):
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
-        # modified-attention module 
+        # modified-attention module
         residual = src
         if self.normalize_before:
            src = self.norm_conv_abs(src)
         src = self.linear1(src)
         src = torch.exp(src.clamp(max=75))
+        src = src.permute(1, 2, 0)
         src = self.conv1d_abs(src)
+        src = self.activation(src).permute(2, 0, 1)
         src = torch.log(src)
         src = self.linear2(src)
         src = residual + self.dropout(src)
@@ -396,7 +399,6 @@ class RelPositionalEncoding(torch.nn.Module):
         Returns:
             torch.Tensor: Encoded tensor (batch, time, `*`).
             torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
-
         """
         self.extend_pe(x)
         x = x * self.xscale
@@ -409,9 +411,11 @@ class RelPositionalEncoding(torch.nn.Module):
         ]
         return self.dropout(x), self.dropout(pos_emb)
 
+
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+    Modified from
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
 
     Args:
         channels (int): The number of channels of conv layers.
@@ -482,54 +486,6 @@ class ConvolutionModule(nn.Module):
         return x.permute(2, 0, 1)
 
 
-class ConvolutionModule_abs(nn.Module):
-    """ConvolutionModule in Conformer model.
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
-
-    Args:
-        channels (int): The number of channels of conv layers.
-        kernel_size (int): Kernerl size of conv layers.
-        bias (bool): Whether to use bias in conv layers (default=True).
-
-    """
-
-    def __init__(
-        self, channels: int, out_channels: int, kernel_size: int, padding: int, bias: bool = True
-    ) -> None:
-        """Construct an ConvolutionModule object."""
-        super(ConvolutionModule_abs, self).__init__()
-        # kernerl_size should be a odd number for 'SAME' padding
-        assert (kernel_size - 1) % 2 == 0
-
-        self.conv1 = nn.Conv1d_abs(
-            channels,
-            out_channels,
-            kernel_size=kernel_size,
-            stride=1,
-            padding=padding,
-            bias=bias,
-        )
-
-        self.activation = nn.ReLU()
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Compute convolution module.
-
-        Args:
-            x: Input tensor (#time, batch, channels).
-
-        Returns:
-            Tensor: Output tensor (#time, batch, channels).
-
-        """
-        # exchange the temporal dimension and the feature dimension
-        x = x.permute(1, 2, 0)  # (#batch, channels, time).
-        x = self.conv1(x)
-        x = self.activation(x)
-
-        return x.permute(2, 0, 1)
-
-
 class Swish(torch.nn.Module):
     """Construct an Swish object."""
 
diff --git a/egs/librispeech/ASR/conformer_ctc/conv1d_abs_attention.py b/egs/librispeech/ASR/conformer_ctc/conv1d_abs_attention.py
new file mode 100644
index 000000000..031868628
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/conv1d_abs_attention.py
@@ -0,0 +1,211 @@
+# coding=utf-8
+import math
+import collections.abc
+from itertools import repeat
+
+import torch
+from torch import Tensor
+from torch.nn.parameter import Parameter
+from torch.nn import functional as F
+from torch.nn import init
+from torch.nn.modules.module import Module
+from torch.nn.common_types import _size_1_t
+from typing import Optional, Tuple
+
+
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable):
+            return tuple(x)
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+def _reverse_repeat_tuple(t, n):
+    r"""Reverse the order of `t` and repeat each element `n` times.
+
+    This can be used to translate padding arg used by Conv and Pooling modules
+    to the ones used by `F.pad`.
+    """
+    return tuple(x for x in reversed(t) for _ in range(n))
+
+
+_single = _ntuple(1)
+_pair = _ntuple(2)
+_triple = _ntuple(3)
+_quadruple = _ntuple(4)
+
+
+class _ConvNd(Module):
+
+    __constants__ = [
+        "stride",
+        "padding",
+        "dilation",
+        "groups",
+        "padding_mode",
+        "output_padding",
+        "in_channels",
+        "out_channels",
+        "kernel_size",
+    ]
+    __annotations__ = {"bias": Optional[torch.Tensor]}
+
+    _in_channels: int
+    out_channels: int
+    kernel_size: Tuple[int, ...]
+    stride: Tuple[int, ...]
+    padding: Tuple[int, ...]
+    dilation: Tuple[int, ...]
+    transposed: bool
+    output_padding: Tuple[int, ...]
+    groups: int
+    padding_mode: str
+    weight: Tensor
+    bias: Optional[Tensor]
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: _size_1_t,
+        stride: _size_1_t,
+        padding: _size_1_t,
+        dilation: _size_1_t,
+        transposed: bool,
+        output_padding: _size_1_t,
+        groups: int,
+        bias: bool,
+        padding_mode: str,
+    ) -> None:
+        super(_ConvNd, self).__init__()
+        if in_channels % groups != 0:
+            raise ValueError("in_channels must be divisible by groups")
+        if out_channels % groups != 0:
+            raise ValueError("out_channels must be divisible by groups")
+        valid_padding_modes = {"zeros", "reflect", "replicate", "circular"}
+        if padding_mode not in valid_padding_modes:
+            raise ValueError(
+                "padding_mode must be one of {}, but got padding_mode='{}'".format(
+                    valid_padding_modes, padding_mode
+                )
+            )
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.transposed = transposed
+        self.output_padding = output_padding
+        self.groups = groups
+        self.padding_mode = padding_mode
+        # `_reversed_padding_repeated_twice` is the padding to be passed to
+        # `F.pad` if needed (e.g., for non-zero padding types that are
+        # implemented as two ops: padding + conv). `F.pad` accepts paddings in
+        # reverse order than the dimension.
+        self._reversed_padding_repeated_twice = _reverse_repeat_tuple(
+            self.padding, 2
+        )
+        if transposed:
+            self.weight = Parameter(
+                torch.Tensor(in_channels, out_channels // groups, *kernel_size)
+            )
+        else:
+            self.weight = Parameter(
+                torch.Tensor(out_channels, in_channels // groups, *kernel_size)
+            )
+        if bias:
+            self.bias = Parameter(torch.Tensor(out_channels))
+        else:
+            self.register_parameter("bias", None)
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+            bound = 1 / math.sqrt(fan_in)
+            init.uniform_(self.bias, -bound, bound)
+
+    def extra_repr(self):
+        s = (
+            "{in_channels}, {out_channels}, kernel_size={kernel_size}"
+            ", stride={stride}"
+        )
+        if self.padding != (0,) * len(self.padding):
+            s += ", padding={padding}"
+        if self.dilation != (1,) * len(self.dilation):
+            s += ", dilation={dilation}"
+        if self.output_padding != (0,) * len(self.output_padding):
+            s += ", output_padding={output_padding}"
+        if self.groups != 1:
+            s += ", groups={groups}"
+        if self.bias is None:
+            s += ", bias=False"
+        if self.padding_mode != "zeros":
+            s += ", padding_mode={padding_mode}"
+        return s.format(**self.__dict__)
+
+    def __setstate__(self, state):
+        super(_ConvNd, self).__setstate__(state)
+        if not hasattr(self, "padding_mode"):
+            self.padding_mode = "zeros"
+
+
+class Conv1dAbs(_ConvNd):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: _size_1_t,
+        stride: _size_1_t = 1,
+        padding: _size_1_t = 0,
+        dilation: _size_1_t = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = "zeros",
+    ):
+        kernel_size = _single(kernel_size)
+        stride = _single(stride)
+        padding = _single(padding)
+        dilation = _single(dilation)
+        super(Conv1dAbs, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            False,
+            _single(0),
+            groups,
+            bias,
+            padding_mode,
+        )
+
+    def forward(self, input: Tensor) -> Tensor:
+        if self.padding_mode != "zeros":
+            return F.conv1d(
+                F.pad(
+                    input,
+                    self._reversed_padding_repeated_twice,
+                    mode=self.padding_mode,
+                ),
+                torch.abs(self.weight),
+                torch.abs(self.bias) if self.bias is not None else None,
+                self.stride,
+                _single(0),
+                self.dilation,
+                self.groups,
+            )
+        return F.conv1d(
+            input,
+            torch.abs(self.weight),
+            torch.abs(self.bias) if self.bias is not None else None,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
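
Note on the modified-attention branch added to ConformerEncoderLayer.forward: it computes exp -> Conv1dAbs -> ReLU -> log, moving to (batch, channels, time) layout for the convolution and back afterwards. The following standalone sketch (not part of the patch) shows that composition, assuming conv1d_abs_attention.py is importable; the layer sizes mirror the values hard-coded in the patch (512 -> 1024 -> 64 -> 512, kernel_size=21, padding=10), while the time and batch sizes below are arbitrary.

import torch
from torch import nn
from conv1d_abs_attention import Conv1dAbs

linear1 = nn.Linear(512, 1024)
conv1d_abs = Conv1dAbs(1024, 64, kernel_size=21, padding=10)
activation = nn.ReLU()
linear2 = nn.Linear(64, 512)

src = torch.randn(100, 8, 512)          # (time, batch, d_model)
x = linear1(src)                        # (time, batch, 1024)
x = torch.exp(x.clamp(max=75))          # strictly positive; clamp avoids overflow
x = x.permute(1, 2, 0)                  # (batch, 1024, time) for Conv1d
x = conv1d_abs(x)                       # non-negative weights keep values positive
x = activation(x).permute(2, 0, 1)      # back to (time, batch, 64)
x = torch.log(x)                        # finite as long as the conv output is > 0
out = linear2(x)                        # (time, batch, 512)
print(out.shape)                        # torch.Size([100, 8, 512])

Since kernel_size=21 with padding=10 and stride 1 preserves the time dimension, the branch can be added to the residual stream like the other Conformer modules.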
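Conv1dAbs itself is a stock Conv1d except that forward applies torch.abs to the weight (and bias), so a positive input can never produce a negative pre-log value. A small sanity check, again only a sketch under the same import assumption:

import torch
from torch.nn import functional as F
from conv1d_abs_attention import Conv1dAbs

conv = Conv1dAbs(4, 2, kernel_size=3, padding=1)
x = torch.rand(1, 4, 10) + 0.1  # strictly positive input

y = conv(x)
y_ref = F.conv1d(
    x,
    conv.weight.abs(),
    conv.bias.abs(),
    conv.stride,
    conv.padding,
    conv.dilation,
    conv.groups,
)
assert torch.allclose(y, y_ref)
# Positive in practice: abs(weight) >= 0 with random init, input > 0,
# abs(bias) >= 0, so torch.log downstream stays finite.
assert (y > 0).all()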