Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Some experiments with modified attention

Commit e442369987 (parent 5d314b03c5)
@@ -50,7 +50,7 @@ class Conformer(Transformer):
         d_model: int = 256,
         nhead: int = 4,
         dim_feedforward: int = 2048,
-        num_encoder_layers: int = 12,
+        num_encoder_layers: int = 12,  # 12
         num_decoder_layers: int = 6,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
@@ -118,6 +118,7 @@ class Conformer(Transformer):
         mask = encoder_padding_mask(x.size(0), supervisions)
         if mask is not None:
             mask = mask.to(x.device)
+        # print("encoder input: ", x[0][0][:20])
         x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)

         if self.normalize_before:
@@ -156,9 +157,6 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )

         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -180,7 +178,9 @@ class ConformerEncoderLayer(nn.Module):
             d_model
         )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
-        self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
+        #define layernorm for conv1d_abs
+        self.norm_conv_abs = nn.LayerNorm(d_model)

         self.ff_scale = 0.5

@@ -193,6 +193,10 @@ class ConformerEncoderLayer(nn.Module):

         self.normalize_before = normalize_before

+        self.linear1 = nn.Linear(512, 1024)
+        self.conv1d_abs = ConvolutionModule_abs(1024, 64, kernel_size=21, padding=10)
+        self.linear2 = nn.Linear(64, 512)
+
     def forward(
         self,
         src: Tensor,
@@ -227,21 +231,18 @@ class ConformerEncoderLayer(nn.Module):
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)

-        # multi-headed self-attention module
+        # modified-attention module
         residual = src
         if self.normalize_before:
-            src = self.norm_mha(src)
-        src_att = self.self_attn(
-            src,
-            src,
-            src,
-            pos_emb=pos_emb,
-            attn_mask=src_mask,
-            key_padding_mask=src_key_padding_mask,
-        )[0]
-        src = residual + self.dropout(src_att)
+            src = self.norm_conv_abs(src)
+        src = self.linear1(src)
+        src = torch.exp(src.clamp(max=75))
+        src = self.conv1d_abs(src)
+        src = torch.log(src)
+        src = self.linear2(src)
+        src = residual + self.dropout(src)
         if not self.normalize_before:
-            src = self.norm_mha(src)
+            src = self.norm_conv_abs(src)

         # convolution module
         residual = src
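Note on the replacement path in the hunk above: the self-attention call is swapped for linear1 -> exp -> conv1d_abs -> log -> linear2. Since exp and log bracket a convolution whose weights are kept nonnegative, the block computes a weighted log-sum-exp (a soft maximum) over a 21-frame window rather than a dot-product attention score, and clamp(max=75) keeps exp() finite in fp32. The hard-coded 512/1024/64 widths suggest this experiment assumes d_model == 512. The sketch below reproduces the same computation with a plain Conv1d whose weights are passed through abs(); it is only an approximation of the commit's ConvolutionModule_abs / nn.Conv1d_abs, which are defined by this branch and only partly shown in this diff.

# Minimal sketch of the replacement branch, NOT the exact commit code:
# ConvolutionModule_abs / nn.Conv1d_abs come from this branch, so a plain
# Conv1d with abs()-constrained weights stands in for them here.
import torch
import torch.nn as nn
import torch.nn.functional as F


class LogDomainConvBlock(nn.Module):
    """y = linear2(log(conv_abs(exp(clamp(linear1(x)))))) with a residual add."""

    def __init__(self, d_model: int = 512, d_hidden: int = 1024,
                 d_conv: int = 64, kernel_size: int = 21) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_hidden)
        # stand-in for ConvolutionModule_abs: weights/bias are made nonnegative
        # with abs() so the conv output stays positive before the log
        self.conv = nn.Conv1d(d_hidden, d_conv, kernel_size,
                              padding=kernel_size // 2)
        self.linear2 = nn.Linear(d_conv, d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # src: (T, B, d_model), same layout as the encoder layer
        residual = src
        x = self.linear1(src)
        x = torch.exp(x.clamp(max=75))          # keep exp() finite in fp32
        x = x.permute(1, 2, 0)                  # (B, C, T) for Conv1d
        x = F.conv1d(x, self.conv.weight.abs(), self.conv.bias.abs(),
                     padding=self.conv.padding[0])
        x = F.relu(x) + 1e-10                   # epsilon added in this sketch only
        x = torch.log(x).permute(2, 0, 1)       # back to (T, B, d_conv)
        x = self.linear2(x)
        return residual + self.dropout(x)

The small epsilon before the log is added only in this sketch; the diff itself calls torch.log directly on the ReLU output of ConvolutionModule_abs, which can produce -inf if the convolution output is exactly zero.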
@@ -408,429 +409,6 @@ class RelPositionalEncoding(torch.nn.Module):
         ]
         return self.dropout(x), self.dropout(pos_emb)


-class RelPositionMultiheadAttention(nn.Module):
-    r"""Multi-Head Attention layer with relative position encoding
-
-    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
-
-    Args:
-        embed_dim: total dimension of the model.
-        num_heads: parallel attention heads.
-        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
-
-    Examples::
-
-        >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
-        >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
-    """
-
-    def __init__(
-        self,
-        embed_dim: int,
-        num_heads: int,
-        dropout: float = 0.0,
-    ) -> None:
-        super(RelPositionMultiheadAttention, self).__init__()
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
-        assert (
-            self.head_dim * num_heads == self.embed_dim
-        ), "embed_dim must be divisible by num_heads"
-
-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
-        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
-
-        # linear transformation for positional encoding.
-        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
-        # these two learnable bias are used in matrix c and matrix d
-        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
-        self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-
-        self._reset_parameters()
-
-    def _reset_parameters(self) -> None:
-        nn.init.xavier_uniform_(self.in_proj.weight)
-        nn.init.constant_(self.in_proj.bias, 0.0)
-        nn.init.constant_(self.out_proj.bias, 0.0)
-
-        nn.init.xavier_uniform_(self.pos_bias_u)
-        nn.init.xavier_uniform_(self.pos_bias_v)
-
-    def forward(
-        self,
-        query: Tensor,
-        key: Tensor,
-        value: Tensor,
-        pos_emb: Tensor,
-        key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
-        attn_mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
-        r"""
-        Args:
-            query, key, value: map a query and a set of key-value pairs to an output.
-            pos_emb: Positional embedding tensor
-            key_padding_mask: if provided, specified padding elements in the key will
-                be ignored by the attention. When given a binary mask and a value is True,
-                the corresponding value on the attention layer will be ignored. When given
-                a byte mask and a value is non-zero, the corresponding value on the attention
-                layer will be ignored
-            need_weights: output attn_output_weights.
-            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
-                the batches while a 3D mask allows to specify a different mask for the entries of each batch.
-
-        Shape:
-            - Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
-              the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
-              the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
-              the embedding dimension.
-            - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
-              the embedding dimension.
-            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
-              If a ByteTensor is provided, the non-zero positions will be ignored while the position
-              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
-              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
-            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
-              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
-              S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
-              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
-              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
-              is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
-              is provided, it will be added to the attention weight.
-
-            - Outputs:
-            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
-              E is the embedding dimension.
-            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
-              L is the target sequence length, S is the source sequence length.
-        """
-        return self.multi_head_attention_forward(
-            query,
-            key,
-            value,
-            pos_emb,
-            self.embed_dim,
-            self.num_heads,
-            self.in_proj.weight,
-            self.in_proj.bias,
-            self.dropout,
-            self.out_proj.weight,
-            self.out_proj.bias,
-            training=self.training,
-            key_padding_mask=key_padding_mask,
-            need_weights=need_weights,
-            attn_mask=attn_mask,
-        )
-
-    def rel_shift(self, x: Tensor) -> Tensor:
-        """Compute relative positional encoding.
-
-        Args:
-            x: Input tensor (batch, head, time1, 2*time1-1).
-                time1 means the length of query vector.
-
-        Returns:
-            Tensor: tensor of shape (batch, head, time1, time2)
-              (note: time2 has the same value as time1, but it is for
-              the key, while time1 is for the query).
-        """
-        (batch_size, num_heads, time1, n) = x.shape
-        assert n == 2 * time1 - 1
-        # Note: TorchScript requires explicit arg for stride()
-        batch_stride = x.stride(0)
-        head_stride = x.stride(1)
-        time1_stride = x.stride(2)
-        n_stride = x.stride(3)
-        return x.as_strided(
-            (batch_size, num_heads, time1, time1),
-            (batch_stride, head_stride, time1_stride - n_stride, n_stride),
-            storage_offset=n_stride * (time1 - 1),
-        )
-
-    def multi_head_attention_forward(
-        self,
-        query: Tensor,
-        key: Tensor,
-        value: Tensor,
-        pos_emb: Tensor,
-        embed_dim_to_check: int,
-        num_heads: int,
-        in_proj_weight: Tensor,
-        in_proj_bias: Tensor,
-        dropout_p: float,
-        out_proj_weight: Tensor,
-        out_proj_bias: Tensor,
-        training: bool = True,
-        key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
-        attn_mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
-        r"""
-        Args:
-            query, key, value: map a query and a set of key-value pairs to an output.
-            pos_emb: Positional embedding tensor
-            embed_dim_to_check: total dimension of the model.
-            num_heads: parallel attention heads.
-            in_proj_weight, in_proj_bias: input projection weight and bias.
-            dropout_p: probability of an element to be zeroed.
-            out_proj_weight, out_proj_bias: the output projection weight and bias.
-            training: apply dropout if is ``True``.
-            key_padding_mask: if provided, specified padding elements in the key will
-                be ignored by the attention. This is an binary mask. When the value is True,
-                the corresponding value on the attention layer will be filled with -inf.
-            need_weights: output attn_output_weights.
-            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
-                the batches while a 3D mask allows to specify a different mask for the entries of each batch.
-
-        Shape:
-            Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
-              the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
-              the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
-              the embedding dimension.
-            - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
-              length, N is the batch size, E is the embedding dimension.
-            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
-              If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
-              will be unchanged. If a BoolTensor is provided, the positions with the
-              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
-            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
-              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
-              S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
-              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
-              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
-              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
-              is provided, it will be added to the attention weight.
-
-            Outputs:
-            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
-              E is the embedding dimension.
-            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
-              L is the target sequence length, S is the source sequence length.
-        """
-
-        tgt_len, bsz, embed_dim = query.size()
-        assert embed_dim == embed_dim_to_check
-        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
-
-        head_dim = embed_dim // num_heads
-        assert (
-            head_dim * num_heads == embed_dim
-        ), "embed_dim must be divisible by num_heads"
-        scaling = float(head_dim) ** -0.5
-
-        if torch.equal(query, key) and torch.equal(key, value):
-            # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
-
-        elif torch.equal(key, value):
-            # encoder-decoder attention
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = 0
-            _end = embed_dim
-            _w = in_proj_weight[_start:_end, :]
-            if _b is not None:
-                _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = embed_dim
-            _end = None
-            _w = in_proj_weight[_start:, :]
-            if _b is not None:
-                _b = _b[_start:]
-            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
-
-        else:
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = 0
-            _end = embed_dim
-            _w = in_proj_weight[_start:_end, :]
-            if _b is not None:
-                _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
-
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = embed_dim
-            _end = embed_dim * 2
-            _w = in_proj_weight[_start:_end, :]
-            if _b is not None:
-                _b = _b[_start:_end]
-            k = nn.functional.linear(key, _w, _b)
-
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = embed_dim * 2
-            _end = None
-            _w = in_proj_weight[_start:, :]
-            if _b is not None:
-                _b = _b[_start:]
-            v = nn.functional.linear(value, _w, _b)
-
-        if attn_mask is not None:
-            assert (
-                attn_mask.dtype == torch.float32
-                or attn_mask.dtype == torch.float64
-                or attn_mask.dtype == torch.float16
-                or attn_mask.dtype == torch.uint8
-                or attn_mask.dtype == torch.bool
-            ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
-                attn_mask.dtype
-            )
-            if attn_mask.dtype == torch.uint8:
-                warnings.warn(
-                    "Byte tensor for attn_mask is deprecated. Use bool tensor instead."
-                )
-                attn_mask = attn_mask.to(torch.bool)
-
-            if attn_mask.dim() == 2:
-                attn_mask = attn_mask.unsqueeze(0)
-                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
-            elif attn_mask.dim() == 3:
-                if list(attn_mask.size()) != [
-                    bsz * num_heads,
-                    query.size(0),
-                    key.size(0),
-                ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
-            else:
-                raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
-                )
-            # attn_mask's dim is 3 now.
-
-        # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
-            warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
-            )
-            key_padding_mask = key_padding_mask.to(torch.bool)
-
-        q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim)
-        k = k.contiguous().view(-1, bsz, num_heads, head_dim)
-        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
-
-        src_len = k.size(0)
-
-        if key_padding_mask is not None:
-            assert key_padding_mask.size(0) == bsz, "{} == {}".format(
-                key_padding_mask.size(0), bsz
-            )
-            assert key_padding_mask.size(1) == src_len, "{} == {}".format(
-                key_padding_mask.size(1), src_len
-            )
-
-        q = q.transpose(0, 1)  # (batch, time1, head, d_k)
-
-        pos_emb_bsz = pos_emb.size(0)
-        assert pos_emb_bsz in (1, bsz)  # actually it is 1
-        p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
-        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
-
-        q_with_bias_u = (q + self.pos_bias_u).transpose(
-            1, 2
-        )  # (batch, head, time1, d_k)
-
-        q_with_bias_v = (q + self.pos_bias_v).transpose(
-            1, 2
-        )  # (batch, head, time1, d_k)
-
-        # compute attention score
-        # first compute matrix a and matrix c
-        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
-        k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
-
-        # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p.transpose(-2, -1)
-        )  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(matrix_bd)
-
-        attn_output_weights = (
-            matrix_ac + matrix_bd
-        ) * scaling  # (batch, head, time1, time2)
-
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
-
-        assert list(attn_output_weights.size()) == [
-            bsz * num_heads,
-            tgt_len,
-            src_len,
-        ]
-
-        if attn_mask is not None:
-            if attn_mask.dtype == torch.bool:
-                attn_output_weights.masked_fill_(attn_mask, float("-inf"))
-            else:
-                attn_output_weights += attn_mask
-
-        if key_padding_mask is not None:
-            attn_output_weights = attn_output_weights.view(
-                bsz, num_heads, tgt_len, src_len
-            )
-            attn_output_weights = attn_output_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2),
-                float("-inf"),
-            )
-            attn_output_weights = attn_output_weights.view(
-                bsz * num_heads, tgt_len, src_len
-            )
-
-        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
-        attn_output_weights = nn.functional.dropout(
-            attn_output_weights, p=dropout_p, training=training
-        )
-
-        attn_output = torch.bmm(attn_output_weights, v)
-        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
-        attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
-        )
-
-        if need_weights:
-            # average attention weights over heads
-            attn_output_weights = attn_output_weights.view(
-                bsz, num_heads, tgt_len, src_len
-            )
-            return attn_output, attn_output_weights.sum(dim=1) / num_heads
-        else:
-            return attn_output, None
-
-
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
     Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
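For reference, the deleted rel_shift() above selects, for every query position i, the Transformer-XL relative-position score at offset j - i out of the (batch, head, time1, 2*time1-1) tensor (matrix_bd), using a single as_strided call instead of padding and reshaping. A small standalone check of that stride trick:

# Quick check of the as_strided trick used by the removed rel_shift():
# rel_shift(x)[b, h, i, j] == x[b, h, i, j - i + time1 - 1], so the last axis
# of x indexes relative offsets -(time1-1) .. +(time1-1).
import torch


def rel_shift(x: torch.Tensor) -> torch.Tensor:
    (batch_size, num_heads, time1, n) = x.shape
    assert n == 2 * time1 - 1
    return x.as_strided(
        (batch_size, num_heads, time1, time1),
        (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
        storage_offset=x.stride(3) * (time1 - 1),
    )


time1 = 3
# x[0, 0, i, k] = 10 * i + k, so the relative index k is easy to read off
x = (10 * torch.arange(time1).view(time1, 1)
     + torch.arange(2 * time1 - 1).view(1, -1)).view(1, 1, time1, -1).float()
out = rel_shift(x)
# Expected:
# tensor([[ 2.,  3.,  4.],
#         [11., 12., 13.],
#         [20., 21., 22.]])
print(out[0, 0])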
@@ -904,6 +482,54 @@ class ConvolutionModule(nn.Module):
         return x.permute(2, 0, 1)


+class ConvolutionModule_abs(nn.Module):
+    """ConvolutionModule in Conformer model.
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+
+    Args:
+        channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernerl size of conv layers.
+        bias (bool): Whether to use bias in conv layers (default=True).
+
+    """
+
+    def __init__(
+        self, channels: int, out_channels: int, kernel_size: int, padding: int, bias: bool = True
+    ) -> None:
+        """Construct an ConvolutionModule object."""
+        super(ConvolutionModule_abs, self).__init__()
+        # kernerl_size should be a odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+
+        self.conv1 = nn.Conv1d_abs(
+            channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=1,
+            padding=padding,
+            bias=bias,
+        )
+
+        self.activation = nn.ReLU()
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+
+        Returns:
+            Tensor: Output tensor (#time, batch, channels).
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+        x = self.conv1(x)
+        x = self.activation(x)
+
+        return x.permute(2, 0, 1)
+
+
 class Swish(torch.nn.Module):
     """Construct an Swish object."""
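nn.Conv1d_abs, used by ConvolutionModule_abs above, is not part of stock PyTorch; it must come from elsewhere on this branch and is not shown in this diff. Judging only by its name and by the fact that the module's output is passed to torch.log() in the encoder layer, it is presumably a Conv1d whose weight (and bias) are constrained to be nonnegative, for example by taking abs() of the parameters in forward. A hypothetical sketch under that assumption:

# Hypothetical stand-in for nn.Conv1d_abs (NOT stock PyTorch, not shown in
# this diff): a Conv1d that applies abs() to its weight and bias so a
# nonnegative input maps to a nonnegative output, keeping a later torch.log()
# well defined.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv1dAbs(nn.Conv1d):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bias = self.bias.abs() if self.bias is not None else None
        return F.conv1d(
            x,
            self.weight.abs(),
            bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )


# Usage mirroring ConvolutionModule_abs(1024, 64, kernel_size=21, padding=10):
conv = Conv1dAbs(1024, 64, kernel_size=21, stride=1, padding=10)
x = torch.rand(8, 1024, 100)   # (batch, channels, time), nonnegative input
y = torch.relu(conv(x))        # stays >= 0

If Conv1d_abs enforces positivity differently (for example by squaring or exponentiating the weights), the sketch would change accordingly; only the nonnegativity requirement is implied by the surrounding code.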