mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Operation for feedforward+mha+conv1dabs+conv+macaron.
This commit is contained in:
parent
309461c185
commit
99274cbb8f
@ -157,7 +157,9 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
normalize_before: bool = True,
|
normalize_before: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
super(ConformerEncoderLayer, self).__init__()
|
super(ConformerEncoderLayer, self).__init__()
|
||||||
|
self.self_attn = RelPositionMultiheadAttention(
|
||||||
|
d_model, nhead, dropout=0.0
|
||||||
|
)
|
||||||
self.feed_forward = nn.Sequential(
|
self.feed_forward = nn.Sequential(
|
||||||
nn.Linear(d_model, dim_feedforward),
|
nn.Linear(d_model, dim_feedforward),
|
||||||
Swish(),
|
Swish(),
|
||||||
@ -178,6 +180,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
d_model
|
d_model
|
||||||
) # for the macaron style FNN module
|
) # for the macaron style FNN module
|
||||||
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
|
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
|
||||||
|
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
|
||||||
|
|
||||||
# define layernorm for conv1d_abs
|
# define layernorm for conv1d_abs
|
||||||
self.norm_conv_abs = nn.LayerNorm(d_model)
|
self.norm_conv_abs = nn.LayerNorm(d_model)
|
||||||
@ -194,7 +197,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
self.normalize_before = normalize_before
|
self.normalize_before = normalize_before
|
||||||
|
|
||||||
self.kernel_size = 21
|
self.kernel_size = 31
|
||||||
self.padding = int((self.kernel_size - 1) / 2)
|
self.padding = int((self.kernel_size - 1) / 2)
|
||||||
self.conv1d_channels = 768
|
self.conv1d_channels = 768
|
||||||
self.linear1 = nn.Linear(512, self.conv1d_channels)
|
self.linear1 = nn.Linear(512, self.conv1d_channels)
|
||||||
@ -240,19 +243,31 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
if not self.normalize_before:
|
if not self.normalize_before:
|
||||||
src = self.norm_ff_macaron(src)
|
src = self.norm_ff_macaron(src)
|
||||||
|
|
||||||
inf = torch.tensor(float("inf"), device=src.device)
|
# multi-head attention module
|
||||||
|
residual = src
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_mha(src)
|
||||||
|
src_att = self.self_attn(
|
||||||
|
src,
|
||||||
|
src,
|
||||||
|
src,
|
||||||
|
pos_emb=pos_emb,
|
||||||
|
attn_mask=src_mask,
|
||||||
|
key_padding_mask=src_key_padding_mask,
|
||||||
|
)[0]
|
||||||
|
src = residual + self.dropout(src_att)
|
||||||
|
if not self.normalize_before:
|
||||||
|
src = self.norm_mha(src)
|
||||||
|
|
||||||
|
# conv1dabs modified attention module
|
||||||
|
inf = torch.tensor(float("inf"), device=src.device)
|
||||||
def check_inf(x):
|
def check_inf(x):
|
||||||
if x.max() == inf:
|
if x.max() == inf:
|
||||||
print("Error: inf found: ", x)
|
print("Error: inf found: ", x)
|
||||||
assert 0
|
assert 0
|
||||||
|
|
||||||
# modified-attention module
|
|
||||||
|
|
||||||
residual = src
|
residual = src
|
||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
src = self.norm_conv_abs(src)
|
src = self.norm_conv_abs(src)
|
||||||
|
|
||||||
src = self.linear1(src * 0.25)
|
src = self.linear1(src * 0.25)
|
||||||
src = torch.exp(src.clamp(min=-75, max=75))
|
src = torch.exp(src.clamp(min=-75, max=75))
|
||||||
check_inf(src)
|
check_inf(src)
|
||||||
@ -266,7 +281,6 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
src = self.layernorm(src)
|
src = self.layernorm(src)
|
||||||
# multipy the output by 0.5 later.
|
# multipy the output by 0.5 later.
|
||||||
# do a comparison.
|
# do a comparison.
|
||||||
|
|
||||||
src = residual + self.dropout(src)
|
src = residual + self.dropout(src)
|
||||||
if not self.normalize_before:
|
if not self.normalize_before:
|
||||||
src = self.norm_conv_abs(src)
|
src = self.norm_conv_abs(src)
|
||||||
@ -436,6 +450,427 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
return self.dropout(x), self.dropout(pos_emb)
|
return self.dropout(x), self.dropout(pos_emb)
|
||||||
|
|
||||||
|
|
||||||
|
class RelPositionMultiheadAttention(nn.Module):
|
||||||
|
r"""Multi-Head Attention layer with relative position encoding
|
||||||
|
|
||||||
|
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embed_dim: total dimension of the model.
|
||||||
|
num_heads: parallel attention heads.
|
||||||
|
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
|
||||||
|
>>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
) -> None:
|
||||||
|
super(RelPositionMultiheadAttention, self).__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.dropout = dropout
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
assert (
|
||||||
|
self.head_dim * num_heads == self.embed_dim
|
||||||
|
), "embed_dim must be divisible by num_heads"
|
||||||
|
|
||||||
|
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
|
||||||
|
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
||||||
|
|
||||||
|
# linear transformation for positional encoding.
|
||||||
|
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
|
||||||
|
# these two learnable bias are used in matrix c and matrix d
|
||||||
|
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||||
|
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||||
|
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||||
|
|
||||||
|
self._reset_parameters()
|
||||||
|
|
||||||
|
def _reset_parameters(self) -> None:
|
||||||
|
nn.init.xavier_uniform_(self.in_proj.weight)
|
||||||
|
nn.init.constant_(self.in_proj.bias, 0.0)
|
||||||
|
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||||
|
|
||||||
|
nn.init.xavier_uniform_(self.pos_bias_u)
|
||||||
|
nn.init.xavier_uniform_(self.pos_bias_v)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: Tensor,
|
||||||
|
key: Tensor,
|
||||||
|
value: Tensor,
|
||||||
|
pos_emb: Tensor,
|
||||||
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
|
need_weights: bool = True,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
query, key, value: map a query and a set of key-value pairs to an output.
|
||||||
|
pos_emb: Positional embedding tensor
|
||||||
|
key_padding_mask: if provided, specified padding elements in the key will
|
||||||
|
be ignored by the attention. When given a binary mask and a value is True,
|
||||||
|
the corresponding value on the attention layer will be ignored. When given
|
||||||
|
a byte mask and a value is non-zero, the corresponding value on the attention
|
||||||
|
layer will be ignored
|
||||||
|
need_weights: output attn_output_weights.
|
||||||
|
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||||
|
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
- Inputs:
|
||||||
|
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
||||||
|
If a ByteTensor is provided, the non-zero positions will be ignored while the position
|
||||||
|
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
|
||||||
|
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
||||||
|
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
||||||
|
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
||||||
|
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
|
||||||
|
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
||||||
|
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
||||||
|
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
||||||
|
is provided, it will be added to the attention weight.
|
||||||
|
|
||||||
|
- Outputs:
|
||||||
|
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
||||||
|
E is the embedding dimension.
|
||||||
|
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
||||||
|
L is the target sequence length, S is the source sequence length.
|
||||||
|
"""
|
||||||
|
return self.multi_head_attention_forward(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
pos_emb,
|
||||||
|
self.embed_dim,
|
||||||
|
self.num_heads,
|
||||||
|
self.in_proj.weight,
|
||||||
|
self.in_proj.bias,
|
||||||
|
self.dropout,
|
||||||
|
self.out_proj.weight,
|
||||||
|
self.out_proj.bias,
|
||||||
|
training=self.training,
|
||||||
|
key_padding_mask=key_padding_mask,
|
||||||
|
need_weights=need_weights,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
def rel_shift(self, x: Tensor) -> Tensor:
|
||||||
|
"""Compute relative positional encoding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor (batch, head, time1, 2*time1-1).
|
||||||
|
time1 means the length of query vector.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: tensor of shape (batch, head, time1, time2)
|
||||||
|
(note: time2 has the same value as time1, but it is for
|
||||||
|
the key, while time1 is for the query).
|
||||||
|
"""
|
||||||
|
(batch_size, num_heads, time1, n) = x.shape
|
||||||
|
assert n == 2 * time1 - 1
|
||||||
|
# Note: TorchScript requires explicit arg for stride()
|
||||||
|
batch_stride = x.stride(0)
|
||||||
|
head_stride = x.stride(1)
|
||||||
|
time1_stride = x.stride(2)
|
||||||
|
n_stride = x.stride(3)
|
||||||
|
return x.as_strided(
|
||||||
|
(batch_size, num_heads, time1, time1),
|
||||||
|
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||||
|
storage_offset=n_stride * (time1 - 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def multi_head_attention_forward(
|
||||||
|
self,
|
||||||
|
query: Tensor,
|
||||||
|
key: Tensor,
|
||||||
|
value: Tensor,
|
||||||
|
pos_emb: Tensor,
|
||||||
|
embed_dim_to_check: int,
|
||||||
|
num_heads: int,
|
||||||
|
in_proj_weight: Tensor,
|
||||||
|
in_proj_bias: Tensor,
|
||||||
|
dropout_p: float,
|
||||||
|
out_proj_weight: Tensor,
|
||||||
|
out_proj_bias: Tensor,
|
||||||
|
training: bool = True,
|
||||||
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
|
need_weights: bool = True,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
query, key, value: map a query and a set of key-value pairs to an output.
|
||||||
|
pos_emb: Positional embedding tensor
|
||||||
|
embed_dim_to_check: total dimension of the model.
|
||||||
|
num_heads: parallel attention heads.
|
||||||
|
in_proj_weight, in_proj_bias: input projection weight and bias.
|
||||||
|
dropout_p: probability of an element to be zeroed.
|
||||||
|
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
||||||
|
training: apply dropout if is ``True``.
|
||||||
|
key_padding_mask: if provided, specified padding elements in the key will
|
||||||
|
be ignored by the attention. This is an binary mask. When the value is True,
|
||||||
|
the corresponding value on the attention layer will be filled with -inf.
|
||||||
|
need_weights: output attn_output_weights.
|
||||||
|
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||||
|
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
Inputs:
|
||||||
|
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
|
||||||
|
length, N is the batch size, E is the embedding dimension.
|
||||||
|
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
||||||
|
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
|
||||||
|
will be unchanged. If a BoolTensor is provided, the positions with the
|
||||||
|
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
||||||
|
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
||||||
|
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
||||||
|
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
|
||||||
|
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
||||||
|
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
||||||
|
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
||||||
|
is provided, it will be added to the attention weight.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
||||||
|
E is the embedding dimension.
|
||||||
|
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
||||||
|
L is the target sequence length, S is the source sequence length.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tgt_len, bsz, embed_dim = query.size()
|
||||||
|
assert embed_dim == embed_dim_to_check
|
||||||
|
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
||||||
|
|
||||||
|
head_dim = embed_dim // num_heads
|
||||||
|
assert (
|
||||||
|
head_dim * num_heads == embed_dim
|
||||||
|
), "embed_dim must be divisible by num_heads"
|
||||||
|
scaling = float(head_dim) ** -0.5
|
||||||
|
|
||||||
|
if torch.equal(query, key) and torch.equal(key, value):
|
||||||
|
# self-attention
|
||||||
|
q, k, v = nn.functional.linear(
|
||||||
|
query, in_proj_weight, in_proj_bias
|
||||||
|
).chunk(3, dim=-1)
|
||||||
|
|
||||||
|
elif torch.equal(key, value):
|
||||||
|
# encoder-decoder attention
|
||||||
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
|
_b = in_proj_bias
|
||||||
|
_start = 0
|
||||||
|
_end = embed_dim
|
||||||
|
_w = in_proj_weight[_start:_end, :]
|
||||||
|
if _b is not None:
|
||||||
|
_b = _b[_start:_end]
|
||||||
|
q = nn.functional.linear(query, _w, _b)
|
||||||
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
|
_b = in_proj_bias
|
||||||
|
_start = embed_dim
|
||||||
|
_end = None
|
||||||
|
_w = in_proj_weight[_start:, :]
|
||||||
|
if _b is not None:
|
||||||
|
_b = _b[_start:]
|
||||||
|
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
|
_b = in_proj_bias
|
||||||
|
_start = 0
|
||||||
|
_end = embed_dim
|
||||||
|
_w = in_proj_weight[_start:_end, :]
|
||||||
|
if _b is not None:
|
||||||
|
_b = _b[_start:_end]
|
||||||
|
q = nn.functional.linear(query, _w, _b)
|
||||||
|
|
||||||
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
|
_b = in_proj_bias
|
||||||
|
_start = embed_dim
|
||||||
|
_end = embed_dim * 2
|
||||||
|
_w = in_proj_weight[_start:_end, :]
|
||||||
|
if _b is not None:
|
||||||
|
_b = _b[_start:_end]
|
||||||
|
k = nn.functional.linear(key, _w, _b)
|
||||||
|
|
||||||
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
|
_b = in_proj_bias
|
||||||
|
_start = embed_dim * 2
|
||||||
|
_end = None
|
||||||
|
_w = in_proj_weight[_start:, :]
|
||||||
|
if _b is not None:
|
||||||
|
_b = _b[_start:]
|
||||||
|
v = nn.functional.linear(value, _w, _b)
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
assert (
|
||||||
|
attn_mask.dtype == torch.float32
|
||||||
|
or attn_mask.dtype == torch.float64
|
||||||
|
or attn_mask.dtype == torch.float16
|
||||||
|
or attn_mask.dtype == torch.uint8
|
||||||
|
or attn_mask.dtype == torch.bool
|
||||||
|
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
|
||||||
|
attn_mask.dtype
|
||||||
|
)
|
||||||
|
if attn_mask.dtype == torch.uint8:
|
||||||
|
warnings.warn(
|
||||||
|
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
|
||||||
|
)
|
||||||
|
attn_mask = attn_mask.to(torch.bool)
|
||||||
|
|
||||||
|
if attn_mask.dim() == 2:
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
||||||
|
raise RuntimeError(
|
||||||
|
"The size of the 2D attn_mask is not correct."
|
||||||
|
)
|
||||||
|
elif attn_mask.dim() == 3:
|
||||||
|
if list(attn_mask.size()) != [
|
||||||
|
bsz * num_heads,
|
||||||
|
query.size(0),
|
||||||
|
key.size(0),
|
||||||
|
]:
|
||||||
|
raise RuntimeError(
|
||||||
|
"The size of the 3D attn_mask is not correct."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
"attn_mask's dimension {} is not supported".format(
|
||||||
|
attn_mask.dim()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# attn_mask's dim is 3 now.
|
||||||
|
|
||||||
|
# convert ByteTensor key_padding_mask to bool
|
||||||
|
if (
|
||||||
|
key_padding_mask is not None
|
||||||
|
and key_padding_mask.dtype == torch.uint8
|
||||||
|
):
|
||||||
|
warnings.warn(
|
||||||
|
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
|
||||||
|
)
|
||||||
|
key_padding_mask = key_padding_mask.to(torch.bool)
|
||||||
|
|
||||||
|
q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim)
|
||||||
|
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
|
||||||
|
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
||||||
|
|
||||||
|
src_len = k.size(0)
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
|
||||||
|
key_padding_mask.size(0), bsz
|
||||||
|
)
|
||||||
|
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
|
||||||
|
key_padding_mask.size(1), src_len
|
||||||
|
)
|
||||||
|
|
||||||
|
q = q.transpose(0, 1) # (batch, time1, head, d_k)
|
||||||
|
|
||||||
|
pos_emb_bsz = pos_emb.size(0)
|
||||||
|
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||||
|
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||||
|
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
||||||
|
|
||||||
|
q_with_bias_u = (q + self.pos_bias_u).transpose(
|
||||||
|
1, 2
|
||||||
|
) # (batch, head, time1, d_k)
|
||||||
|
|
||||||
|
q_with_bias_v = (q + self.pos_bias_v).transpose(
|
||||||
|
1, 2
|
||||||
|
) # (batch, head, time1, d_k)
|
||||||
|
|
||||||
|
# compute attention score
|
||||||
|
# first compute matrix a and matrix c
|
||||||
|
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||||
|
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
|
||||||
|
matrix_ac = torch.matmul(
|
||||||
|
q_with_bias_u, k
|
||||||
|
) # (batch, head, time1, time2)
|
||||||
|
|
||||||
|
# compute matrix b and matrix d
|
||||||
|
matrix_bd = torch.matmul(
|
||||||
|
q_with_bias_v, p.transpose(-2, -1)
|
||||||
|
) # (batch, head, time1, 2*time1-1)
|
||||||
|
matrix_bd = self.rel_shift(matrix_bd)
|
||||||
|
|
||||||
|
attn_output_weights = (
|
||||||
|
matrix_ac + matrix_bd
|
||||||
|
) * scaling # (batch, head, time1, time2)
|
||||||
|
|
||||||
|
attn_output_weights = attn_output_weights.view(
|
||||||
|
bsz * num_heads, tgt_len, -1
|
||||||
|
)
|
||||||
|
|
||||||
|
assert list(attn_output_weights.size()) == [
|
||||||
|
bsz * num_heads,
|
||||||
|
tgt_len,
|
||||||
|
src_len,
|
||||||
|
]
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
if attn_mask.dtype == torch.bool:
|
||||||
|
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
|
||||||
|
else:
|
||||||
|
attn_output_weights += attn_mask
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
attn_output_weights = attn_output_weights.view(
|
||||||
|
bsz, num_heads, tgt_len, src_len
|
||||||
|
)
|
||||||
|
attn_output_weights = attn_output_weights.masked_fill(
|
||||||
|
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
||||||
|
float("-inf"),
|
||||||
|
)
|
||||||
|
attn_output_weights = attn_output_weights.view(
|
||||||
|
bsz * num_heads, tgt_len, src_len
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
|
||||||
|
attn_output_weights = nn.functional.dropout(
|
||||||
|
attn_output_weights, p=dropout_p, training=training
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = torch.bmm(attn_output_weights, v)
|
||||||
|
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
||||||
|
attn_output = (
|
||||||
|
attn_output.transpose(0, 1)
|
||||||
|
.contiguous()
|
||||||
|
.view(tgt_len, bsz, embed_dim)
|
||||||
|
)
|
||||||
|
attn_output = nn.functional.linear(
|
||||||
|
attn_output, out_proj_weight, out_proj_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
if need_weights:
|
||||||
|
# average attention weights over heads
|
||||||
|
attn_output_weights = attn_output_weights.view(
|
||||||
|
bsz, num_heads, tgt_len, src_len
|
||||||
|
)
|
||||||
|
return attn_output, attn_output_weights.sum(dim=1) / num_heads
|
||||||
|
else:
|
||||||
|
return attn_output, None
|
||||||
|
|
||||||
class ConvolutionModule(nn.Module):
|
class ConvolutionModule(nn.Module):
|
||||||
"""ConvolutionModule in Conformer model.
|
"""ConvolutionModule in Conformer model.
|
||||||
Modified from
|
Modified from
|
||||||
|
Loading…
x
Reference in New Issue
Block a user