modified conv1dabs attention with pad 1

Mingshuang Luo 2022-01-26 11:46:03 +08:00
parent ae667f4801
commit cbc6f01861
2 changed files with 42 additions and 39 deletions

Changed file 1: the Conformer encoder module (Conformer, ConformerEncoderLayer, RelPositionMultiheadAttention)

@@ -23,6 +23,7 @@ import torch
 from torch import Tensor, nn
 from conv1d_abs_attention import Conv1dAbs
 from transformer import Supervisions, Transformer, encoder_padding_mask
+import logging

 class Conformer(Transformer):
@@ -157,6 +158,7 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=0.0
         )
@@ -180,6 +182,7 @@ class ConformerEncoderLayer(nn.Module):
             d_model
         )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
         # define layernorm for conv1d_abs
@@ -201,13 +204,10 @@ class ConformerEncoderLayer(nn.Module):
         self.padding = int((self.kernel_size-1)/2)
         self.in_conv1d_channels = 768
         self.out_conv1d_channels = 768
+        # kernel size=21, self.conv1d_channels=768
+        # kernel size=5, self.conv1d_channels=1024
         self.linear1 = nn.Linear(512, self.in_conv1d_channels)
-        self.conv1d_abs = Conv1dAbs(
-            self.in_conv1d_channels,
-            self.out_conv1d_channels,
-            kernel_size=self.kernel_size,
-            padding=self.padding,
-        )
+        self.conv1d_abs = Conv1dAbs(self.in_conv1d_channels, self.out_conv1d_channels, kernel_size=self.kernel_size, padding=self.padding, padding_mode="replicate")
         self.linear2 = nn.Linear(self.out_conv1d_channels, 512)

     def forward(
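Because self.padding = int((self.kernel_size-1)/2) (unchanged context above) and the kernel sizes noted in the added comments are odd, the convolution keeps the time dimension unchanged, so its output can flow back into the residual stream. A quick shape check, using a plain nn.Conv1d as a stand-in for Conv1dAbs and assuming kernel size 21 with 768 channels (the first setting listed in the added comments); everything here is illustrative, not the repo's code:

import torch
import torch.nn as nn

kernel_size = 21                     # first setting in the added comments
padding = (kernel_size - 1) // 2     # = 10, same value as int((kernel_size - 1) / 2)

# Plain Conv1d as a stand-in for Conv1dAbs, only to check the shape arithmetic:
# with stride 1, an odd kernel and this padding leave the time length unchanged.
conv = nn.Conv1d(768, 768, kernel_size=kernel_size,
                 padding=padding, padding_mode="replicate")

x = torch.randn(4, 768, 100)         # (batch, channels, time)
print(conv(x).shape)                 # torch.Size([4, 768, 100])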
@@ -244,7 +244,7 @@ class ConformerEncoderLayer(nn.Module):
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
-        # multi-head attention module
+        # multi-head attention
         residual = src
         if self.normalize_before:
             src = self.norm_mha(src)
@@ -260,7 +260,7 @@ class ConformerEncoderLayer(nn.Module):
         if not self.normalize_before:
             src = self.norm_mha(src)
-        # conv1dabs modified attention module
+        # conv1dabs modified attention
         residual = src
         if self.normalize_before:
             src = self.norm_conv_abs(src)
@@ -268,10 +268,11 @@ class ConformerEncoderLayer(nn.Module):
         #src = self.linear1(src*0.25)
         src = 0.01*self.linear1(src*0.25)
         src = torch.exp(src.clamp(min=-75, max=75))
-        src = src.permute(1, 2, 0)
+        src = src.permute(1, 2, 0)  # (B, D, T)
+        src = src.permute(0, 2, 1)  # (B, T, D)
         src = self.conv1d_abs(src) / self.kernel_size
         src = src.permute(2, 0, 1)
-        src = torch.log(src.clamp(min=1e-20))
+        src = torch.log(0.01 + src.clamp(min=1e-20))
         src = self.linear2(src)
         src = 0.25*self.layernorm(src)
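Taken together, this block is the "conv1dabs modified attention": scale the features down, exponentiate them, average them over a local window with a positivity-constrained convolution, then map back with a shifted log. A self-contained sketch of that data flow, using an ordinary Conv1d with exponentiated weights in place of Conv1dAbs; the kernel size of 21 and the tensor sizes are assumptions for illustration only:

import torch
import torch.nn as nn
import torch.nn.functional as F

T, B, D = 50, 4, 512            # (time, batch, feature), as in the encoder layer
hidden, k = 768, 21             # conv channels from the code above; kernel size assumed

linear1 = nn.Linear(D, hidden)
conv = nn.Conv1d(hidden, hidden, kernel_size=k, padding=(k - 1) // 2)
linear2 = nn.Linear(hidden, D)
layernorm = nn.LayerNorm(D)

src = torch.randn(T, B, D)
x = 0.01 * linear1(src * 0.25)                  # scale down before exponentiating
x = torch.exp(x.clamp(min=-75, max=75))         # strictly positive features
x = x.permute(1, 2, 0)                          # (B, hidden, T), channels-first for conv1d
# Conv1dAbs constrains its kernel and bias to be positive (exp of the parameters
# in this commit); emulate that here with exp() on an ordinary Conv1d's weights.
x = F.conv1d(x, conv.weight.exp(), conv.bias.exp(), padding=(k - 1) // 2) / k
x = x.permute(2, 0, 1)                          # back to (T, B, hidden)
x = torch.log(0.01 + x.clamp(min=1e-20))        # undo the exp, keep it finite
out = 0.25 * layernorm(linear2(x))
print(out.shape)                                # torch.Size([50, 4, 512])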
@@ -415,8 +416,8 @@ class RelPositionalEncoding(torch.nn.Module):
         pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
         # Reserve the order of positive indices and concat both positive and
-        # negative indices. This is used to support the shifting trick as in "T
-        # ransformer-XL:Attentive Language Models Beyond a Fixed-Length Context"
+        # negative indices. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
         pe_negative = pe_negative[1:].unsqueeze(0)
         pe = torch.cat([pe_positive, pe_negative], dim=1)
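These lines belong to the standard Transformer-XL style construction: sinusoidal encodings are built separately for non-negative and negative offsets, the non-negative half is reversed so offsets run from +(T-1) down to 0, and both halves are concatenated into a single tensor covering 2*T - 1 relative positions. A small self-contained sketch of that construction; the sizes T=5 and D=8 are toy values for illustration:

import math
import torch

T, D = 5, 8
position = torch.arange(0, T, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, D, 2, dtype=torch.float32) * -(math.log(10000.0) / D)
)

pe_positive = torch.zeros(T, D)               # encodings for offsets 0 .. T-1
pe_negative = torch.zeros(T, D)               # encodings for offsets 0 .. -(T-1)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

# Reverse the positive half so offsets run +(T-1) .. 0, drop the duplicate
# zero-offset row from the negative half, and concatenate.
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)   # (1, T, D)
pe_negative = pe_negative[1:].unsqueeze(0)                # (1, T-1, D)
pe = torch.cat([pe_positive, pe_negative], dim=1)
print(pe.shape)                               # torch.Size([1, 9, 8]), i.e. (1, 2*T - 1, D)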
@@ -443,19 +444,14 @@ class RelPositionalEncoding(torch.nn.Module):
         ]
         return self.dropout(x), self.dropout(pos_emb)

 class RelPositionMultiheadAttention(nn.Module):
     r"""Multi-Head Attention layer with relative position encoding
     See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
     Args:
         embed_dim: total dimension of the model.
         num_heads: parallel attention heads.
         dropout: a Dropout layer on attn_output_weights. Default: 0.0.
     Examples::
         >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
         >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
     """
@@ -517,7 +513,6 @@ class RelPositionMultiheadAttention(nn.Module):
             need_weights: output attn_output_weights.
             attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                 the batches while a 3D mask allows to specify a different mask for the entries of each batch.
         Shape:
             - Inputs:
             - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
@@ -539,7 +534,6 @@ class RelPositionMultiheadAttention(nn.Module):
                 while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
                 is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
                 is provided, it will be added to the attention weight.
             - Outputs:
             - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
               E is the embedding dimension.
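For reference, a key_padding_mask with the BoolTensor semantics described above (True marks padded key positions that must not be attended to) can be derived from sequence lengths as below. This is only an illustration; the recipe itself builds its mask with the encoder_padding_mask helper imported at the top of the file:

import torch

# Illustrative only: build an (N, S) BoolTensor key_padding_mask where True
# marks padded (ignored) key positions, matching the semantics described above.
lengths = torch.tensor([5, 3, 4])          # valid lengths of 3 sequences
S = int(lengths.max())                     # padded source length
key_padding_mask = torch.arange(S).unsqueeze(0) >= lengths.unsqueeze(1)
print(key_padding_mask)
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True],
#         [False, False, False, False,  True]])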
@@ -566,11 +560,9 @@ class RelPositionMultiheadAttention(nn.Module):
     def rel_shift(self, x: Tensor) -> Tensor:
         """Compute relative positional encoding.
         Args:
             x: Input tensor (batch, head, time1, 2*time1-1).
                 time1 means the length of query vector.
         Returns:
             Tensor: tensor of shape (batch, head, time1, time2)
             (note: time2 has the same value as time1, but it is for
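The shift this docstring refers to is usually implemented with a pad-and-reshape trick that turns the (batch, head, time1, 2*time1-1) score tensor into (batch, head, time1, time1). The sketch below shows one common version of that operation; it illustrates the general technique and is not claimed to match this file line for line:

import torch
from torch import Tensor

def rel_shift_sketch(x: Tensor) -> Tensor:
    """Turn (batch, head, time1, 2*time1 - 1) relative scores into
    (batch, head, time1, time1) by the usual pad-and-reshape trick."""
    b, h, t1, n = x.shape                              # n == 2*t1 - 1
    zero_pad = torch.zeros(b, h, t1, 1, dtype=x.dtype, device=x.device)
    x_padded = torch.cat([zero_pad, x], dim=-1)        # (b, h, t1, 2*t1)
    x_padded = x_padded.view(b, h, 2 * t1, t1)         # realign the rows
    return x_padded[:, :, 1:].view(b, h, t1, n)[:, :, :, :t1]

scores = torch.randn(2, 4, 6, 11)              # time1 = 6, so 2*time1 - 1 = 11
print(rel_shift_sketch(scores).shape)          # torch.Size([2, 4, 6, 6])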
@@ -623,7 +615,6 @@ class RelPositionMultiheadAttention(nn.Module):
             need_weights: output attn_output_weights.
             attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                 the batches while a 3D mask allows to specify a different mask for the entries of each batch.
         Shape:
             Inputs:
             - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
@@ -645,7 +636,6 @@ class RelPositionMultiheadAttention(nn.Module):
                 while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
                 are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
                 is provided, it will be added to the attention weight.
             Outputs:
             - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
               E is the embedding dimension.
@@ -865,7 +855,6 @@ class RelPositionMultiheadAttention(nn.Module):
         else:
             return attn_output, None
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
     Modified from

Changed file 2: the conv1d_abs_attention module (Conv1dAbs)

@@ -153,6 +153,19 @@ class _ConvNd(Module):
         if not hasattr(self, "padding_mode"):
             self.padding_mode = "zeros"

+import torch
+import torch.nn as nn
+m = nn.Tanh()
+def padding(input, padding_length):
+    # input shape : (B, T, D)
+    device = input.device
+    B, T, D = input.shape
+    src = torch.ones(B, T + 2*padding_length[0], D).to(device)
+    src[:, padding_length[0]:T+padding_length[0], :] = input
+    src = src.permute(0, 2, 1)  # src shape: (B, D, T')
+    return src
 class Conv1dAbs(_ConvNd):
     def __init__(
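The new module-level padding helper pads the time axis with ones on both sides (rather than zeros or edge values) and returns the result in (B, D, T') layout for F.conv1d. A quick shape check that copies the helper's logic; the sizes and the pad value 3 are illustrative:

import torch

def padding(input, padding_length):
    # Same logic as the helper added above: pad the time axis with ones on
    # both sides, then move channels to dim 1 for conv1d.
    device = input.device
    B, T, D = input.shape
    src = torch.ones(B, T + 2 * padding_length[0], D).to(device)
    src[:, padding_length[0]:T + padding_length[0], :] = input
    src = src.permute(0, 2, 1)                # (B, D, T + 2 * pad)
    return src

x = torch.randn(2, 7, 16)                     # (B, T, D)
out = padding(x, (3,))                        # padding arrives as a tuple, e.g. (3,)
print(out.shape)                              # torch.Size([2, 16, 13])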
@@ -188,13 +201,14 @@ class Conv1dAbs(_ConvNd):
     def forward(self, input: Tensor) -> Tensor:
         if self.padding_mode != "zeros":
             return F.conv1d(
-                F.pad(
-                    input,
-                    self._reversed_padding_repeated_twice,
-                    mode=self.padding_mode,
-                ),
-                torch.abs(self.weight),
-                torch.abs(self.bias),
+                # F.pad(
+                #     input,
+                #     self._reversed_padding_repeated_twice,
+                #     mode=self.padding_mode,
+                # ),
+                padding(input, self.padding),
+                torch.exp(self.weight),
+                torch.exp(self.bias),
                 self.stride,
                 _single(0),
                 self.dilation,
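With this change the non-"zeros" branch no longer calls F.pad: the input is ones-padded by the helper above, and both the kernel and the bias go through exp, so every coefficient is strictly positive and the log applied later in the encoder layer stays well defined. A standalone sketch of the resulting computation with plain F.conv1d; the sizes are illustrative and this is not the Conv1dAbs class itself:

import torch
import torch.nn.functional as F

B, T, D_in, D_out, k = 2, 10, 16, 16, 5
pad = ((k - 1) // 2,)                       # stored as a tuple, like Conv1d's padding

weight = torch.randn(D_out, D_in, k)        # raw parameters, any sign
bias = torch.randn(D_out)

x = torch.randn(B, T, D_in)                 # (B, T, D) input, as the helper expects
x_padded = torch.ones(B, T + 2 * pad[0], D_in)
x_padded[:, pad[0]:T + pad[0], :] = x       # ones-padding on the time axis
x_padded = x_padded.permute(0, 2, 1)        # (B, D_in, T + 2*pad)

# exp() keeps every kernel/bias entry strictly positive, so a positive input
# stays positive and a later log() of the output is well defined.
out = F.conv1d(x_padded, weight.exp(), bias.exp(), stride=1, padding=0)
print(out.shape)                            # torch.Size([2, 16, 10])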