mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-27 10:44:19 +00:00)

commit cbc6f01861
parent ae667f4801

modified conv1dabs attention with pad 1
@@ -23,6 +23,7 @@ import torch
 from torch import Tensor, nn
 from conv1d_abs_attention import Conv1dAbs
 from transformer import Supervisions, Transformer, encoder_padding_mask
+import logging


 class Conformer(Transformer):
@@ -157,6 +158,7 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
+
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=0.0
         )
@@ -180,6 +182,7 @@ class ConformerEncoderLayer(nn.Module):
             d_model
         )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
+
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module

         # define layernorm for conv1d_abs
@@ -198,16 +201,13 @@ class ConformerEncoderLayer(nn.Module):
         self.normalize_before = normalize_before

         self.kernel_size = 31
-        self.padding = int((self.kernel_size - 1) / 2)
+        self.padding = int((self.kernel_size-1)/2)
         self.in_conv1d_channels = 768
         self.out_conv1d_channels = 768
+        # kernel size=21, self.conv1d_channels=768
+        # kernel size=5, self.conv1d_channels=1024
         self.linear1 = nn.Linear(512, self.in_conv1d_channels)
-        self.conv1d_abs = Conv1dAbs(
-            self.in_conv1d_channels,
-            self.out_conv1d_channels,
-            kernel_size=self.kernel_size,
-            padding=self.padding,
-        )
+        self.conv1d_abs = Conv1dAbs(self.in_conv1d_channels, self.out_conv1d_channels, kernel_size=self.kernel_size, padding=self.padding, padding_mode="replicate")
        self.linear2 = nn.Linear(self.out_conv1d_channels, 512)

     def forward(
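Note on the hyperparameters above: kernel_size=31 with padding=(31-1)/2=15 is "same" padding at stride 1, so the Conv1dAbs layer preserves the sequence length, and padding_mode="replicate" repeats edge frames instead of zero-padding. A minimal shape check, using a plain nn.Conv1d as a stand-in for Conv1dAbs (hypothetical values, not the author's module):

import torch
import torch.nn as nn

kernel_size = 31
padding = (kernel_size - 1) // 2  # 15, "same" padding for stride 1

conv = nn.Conv1d(768, 768, kernel_size=kernel_size,
                 padding=padding, padding_mode="replicate")
x = torch.randn(4, 768, 100)      # (batch, channels, time)
print(conv(x).shape)              # torch.Size([4, 768, 100]) -- length preserved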
@@ -244,7 +244,7 @@ class ConformerEncoderLayer(nn.Module):
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)

-        # multi-head attention module
+        # multi-head attention
         residual = src
         if self.normalize_before:
             src = self.norm_mha(src)
@@ -260,20 +260,21 @@ class ConformerEncoderLayer(nn.Module):
         if not self.normalize_before:
             src = self.norm_mha(src)

-        # conv1dabs modified attention module
+        # conv1dabs modified attention
         residual = src
         if self.normalize_before:
             src = self.norm_conv_abs(src)

-        # src = self.linear1(src * 0.25)
-        src = 0.01 * self.linear1(src * 0.25)
+        #src = self.linear1(src*0.25)
+        src = 0.01*self.linear1(src*0.25)
         src = torch.exp(src.clamp(min=-75, max=75))
-        src = src.permute(1, 2, 0)
+        src = src.permute(1, 2, 0) # (B, D, T)
+        src = src.permute(0, 2, 1) # (B, T, D)
         src = self.conv1d_abs(src) / self.kernel_size
         src = src.permute(2, 0, 1)
-        src = torch.log(src.clamp(min=1e-20))
+        src = torch.log(0.01 + src.clamp(min=1e-20))
         src = self.linear2(src)
-        src = 0.25 * self.layernorm(src)
+        src = 0.25*self.layernorm(src)

         src = residual + self.dropout(src)
         if not self.normalize_before:
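The block above does its local aggregation in the exponential domain: the 512-d features are projected to 768 channels, exponentiated, averaged over a 31-frame window by the positive-weight Conv1dAbs, and mapped back through a log before the final projection; roughly a log-mean-exp pooling over time per channel. A self-contained sketch of the same sequence of operations, using torch.exp of an ordinary nn.Conv1d's parameters as a stand-in for Conv1dAbs (all names and shapes here are illustrative assumptions, not the author's exact module):

import torch
import torch.nn as nn
import torch.nn.functional as F

def conv1dabs_block_sketch(src: torch.Tensor) -> torch.Tensor:
    # src: (T, B, 512), sequence-length first, as in ConformerEncoderLayer.forward.
    linear1 = nn.Linear(512, 768)
    linear2 = nn.Linear(768, 512)
    conv = nn.Conv1d(768, 768, kernel_size=31, padding=15)  # stand-in parameters

    x = 0.01 * linear1(src * 0.25)            # (T, B, 768)
    x = torch.exp(x.clamp(min=-75, max=75))   # strictly positive, exp domain
    x = x.permute(1, 2, 0)                    # (B, 768, T), channels first for conv1d
    # Conv1dAbs in the diff takes (B, T, D) and pads/permutes internally; here we
    # call F.conv1d directly with exp-parameterized (hence positive) weights.
    x = F.conv1d(x, torch.exp(conv.weight), torch.exp(conv.bias), padding=15) / 31
    x = x.permute(2, 0, 1)                    # back to (T, B, 768)
    x = torch.log(0.01 + x.clamp(min=1e-20))  # return to log domain
    return linear2(x)                         # (T, B, 512)

out = conv1dabs_block_sketch(torch.randn(100, 4, 512))
print(out.shape)  # torch.Size([100, 4, 512])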
@@ -415,8 +416,8 @@ class RelPositionalEncoding(torch.nn.Module):
         pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

         # Reserve the order of positive indices and concat both positive and
-        # negative indices. This is used to support the shifting trick as in "T
-        # ransformer-XL:Attentive Language Models Beyond a Fixed-Length Context"
+        # negative indices. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
         pe_negative = pe_negative[1:].unsqueeze(0)
         pe = torch.cat([pe_positive, pe_negative], dim=1)
@@ -443,19 +444,14 @@ class RelPositionalEncoding(torch.nn.Module):
         ]
         return self.dropout(x), self.dropout(pos_emb)

-
 class RelPositionMultiheadAttention(nn.Module):
     r"""Multi-Head Attention layer with relative position encoding
-
     See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
-
     Args:
         embed_dim: total dimension of the model.
         num_heads: parallel attention heads.
         dropout: a Dropout layer on attn_output_weights. Default: 0.0.
-
     Examples::
-
         >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
         >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
     """
@@ -517,7 +513,6 @@ class RelPositionMultiheadAttention(nn.Module):
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.
-
        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
@@ -539,7 +534,6 @@ class RelPositionMultiheadAttention(nn.Module):
                while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
                is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
                is provided, it will be added to the attention weight.
-
            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
                E is the embedding dimension.
@@ -566,11 +560,9 @@ class RelPositionMultiheadAttention(nn.Module):

     def rel_shift(self, x: Tensor) -> Tensor:
         """Compute relative positional encoding.
-
         Args:
             x: Input tensor (batch, head, time1, 2*time1-1).
                 time1 means the length of query vector.
-
         Returns:
             Tensor: tensor of shape (batch, head, time1, time2)
             (note: time2 has the same value as time1, but it is for
@@ -623,7 +615,6 @@ class RelPositionMultiheadAttention(nn.Module):
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.
-
        Shape:
            Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
@@ -645,7 +636,6 @@ class RelPositionMultiheadAttention(nn.Module):
                while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
                are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
                is provided, it will be added to the attention weight.
-
            Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
                E is the embedding dimension.
@@ -865,7 +855,6 @@ class RelPositionMultiheadAttention(nn.Module):
         else:
             return attn_output, None

-
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
     Modified from

@@ -153,6 +153,19 @@ class _ConvNd(Module):
         if not hasattr(self, "padding_mode"):
             self.padding_mode = "zeros"

+import torch
+import torch.nn as nn
+m = nn.Tanh()
+
+def padding(input, padding_length):
+    # input shape : (B, T, D)
+    device = input.device
+    B, T, D = input.shape
+    src = torch.ones(B, T + 2*padding_length[0], D).to(device)
+    src[:, padding_length[0]:T+padding_length[0], :] = input
+    src = src.permute(0, 2, 1) # src shape: (B, D, T')
+
+    return src

 class Conv1dAbs(_ConvNd):
     def __init__(
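The padding() helper added above pads the time axis on both sides with ones rather than zeros and returns the tensor channels-first. Since the tensor reaching Conv1dAbs has already been through torch.exp, a pad value of 1 corresponds to exp(0), i.e. a neutral frame in the pre-exp domain; this is likely the "pad 1" in the commit message, though that is an inference. A quick check of its shape contract, assuming the helper above is in scope and using made-up sizes:

import torch

x = torch.rand(2, 10, 4)        # (B, T, D), toy input
y = padding(x, (3,))            # padding_length is a tuple, like self.padding
print(y.shape)                  # torch.Size([2, 4, 16]): (B, D, T + 2*3)
print(torch.equal(y[:, :, 3:13], x.permute(0, 2, 1)))  # True: middle is the input
print(bool(y[:, :, :3].eq(1).all()))                   # True: both ends padded with ones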
@@ -188,13 +201,14 @@ class Conv1dAbs(_ConvNd):
     def forward(self, input: Tensor) -> Tensor:
         if self.padding_mode != "zeros":
             return F.conv1d(
-                F.pad(
-                    input,
-                    self._reversed_padding_repeated_twice,
-                    mode=self.padding_mode,
-                ),
-                torch.abs(self.weight),
-                torch.abs(self.bias),
+                # F.pad(
+                #     input,
+                #     self._reversed_padding_repeated_twice,
+                #     mode=self.padding_mode,
+                # ),
+                padding(input, self.padding),
+                torch.exp(self.weight),
+                torch.exp(self.bias),
                 self.stride,
                 _single(0),
                 self.dilation,
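This forward change swaps the abs-parameterized weights for exp-parameterized ones (torch.abs(...) to torch.exp(...)) and replaces F.pad's replicate padding with the hand-rolled ones-padding above, so F.conv1d runs with zero internal padding (_single(0)). With strictly positive weights and a strictly positive exp-domain input, the convolution output is strictly positive and the torch.log applied later in ConformerEncoderLayer is well defined. A small numerical illustration of the two parameterizations (hypothetical tensors, not the module's real parameters):

import torch
import torch.nn.functional as F

w = torch.randn(8, 4, 31)               # raw kernel, arbitrary sign
x = torch.rand(2, 4, 100) + 1e-3        # strictly positive "exp-domain" input
y_abs = F.conv1d(x, torch.abs(w), padding=15)  # old: |w| >= 0, may contain exact zeros
y_exp = F.conv1d(x, torch.exp(w), padding=15)  # new: exp(w) > 0 everywhere
print(bool((y_exp > 0).all()))   # True: safe to take a log afterwards
print(bool((y_abs >= 0).all()))  # True, but only non-strict positivity is guaranteed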