mirror of https://github.com/k2-fsa/icefall.git
support position encoding

This commit is contained in:
parent 4929faeed4
commit 630626a092

.flake8 (3 lines changed)
@@ -9,8 +9,7 @@ per-file-ignores =
     egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
     egs/*/ASR/*/optim.py: E501,
     egs/*/ASR/*/scaling.py: E501,
-    egs/librispeech/ASR/conv_emformer_transducer_stateless/*.py: E501, E203
-    egs/librispeech/ASR/conv_emformer_transducer_stateless2/*.py: E501, E203
+    egs/librispeech/ASR/conv_emformer_transducer_stateless*/emformer.py: E501, E203
 
     # invalid escape sequence (cause by tex formular), W605
     icefall/utils.py: E501, W605
@@ -35,7 +35,6 @@ from scaling import (
 
 from icefall.utils import make_pad_mask
 
-
 LOG_EPSILON = math.log(1e-10)
 
 
@@ -434,6 +433,10 @@ class EmformerAttention(nn.Module):
     r"""Emformer layer attention module.
 
     Args:
+      chunk_length (int):
+        Length of chunk.
+      right_context_length (int):
+        Length of right context.
       embed_dim (int):
         Embedding dimension.
       nhead (int):
@@ -448,6 +451,8 @@ class EmformerAttention(nn.Module):
 
     def __init__(
         self,
+        chunk_length: int,
+        right_context_length: int,
         embed_dim: int,
         nhead: int,
         dropout: float = 0.0,
@@ -455,6 +460,8 @@ class EmformerAttention(nn.Module):
         negative_inf: float = -1e8,
     ):
         super().__init__()
+        self.chunk_length = chunk_length
+        self.right_context_length = right_context_length
 
         if embed_dim % nhead != 0:
             raise ValueError(
@@ -477,6 +484,26 @@ class EmformerAttention(nn.Module):
             embed_dim, embed_dim, bias=True, initial_scale=0.25
         )
 
+        # linear transformation for positional encoding.
+        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
+        # these two learnable bias are used in matrix c and matrix d
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # noqa
+        self.pos_bias_u = nn.Parameter(torch.Tensor(nhead, self.head_dim))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(nhead, self.head_dim))
+        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
+        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
+        self._reset_parameters()
+
+    def _pos_bias_u(self):
+        return self.pos_bias_u * self.pos_bias_u_scale.exp()
+
+    def _pos_bias_v(self):
+        return self.pos_bias_v * self.pos_bias_v_scale.exp()
+
+    def _reset_parameters(self) -> None:
+        nn.init.normal_(self.pos_bias_u, std=0.01)
+        nn.init.normal_(self.pos_bias_v, std=0.01)
+
     def _gen_attention_probs(
         self,
         attention_weights: torch.Tensor,
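A minimal standalone sketch (toy sizes, plain PyTorch, not taken from the recipe) of the bias parameterization added above: each Transformer-XL position bias is stored together with a scalar log-scale, so the effective bias is bias * exp(scale), which starts at exp(0) = 1 and stays positive during training. Names mirror pos_bias_u / pos_bias_u_scale from the diff.

import torch
import torch.nn as nn

class PosBias(nn.Module):
    # Illustrative module, not part of icefall.
    def __init__(self, nhead: int, head_dim: int):
        super().__init__()
        self.bias = nn.Parameter(torch.empty(nhead, head_dim))
        self.scale = nn.Parameter(torch.zeros(()))  # log-scale, starts at 0
        nn.init.normal_(self.bias, std=0.01)

    def forward(self) -> torch.Tensor:
        # effective bias = bias * exp(log_scale)
        return self.bias * self.scale.exp()

pb = PosBias(nhead=8, head_dim=64)
assert pb().shape == (8, 64)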
@@ -539,6 +566,8 @@ class EmformerAttention(nn.Module):
         right_context: torch.Tensor,
         memory: torch.Tensor,
         attention_mask: torch.Tensor,
+        pos_emb: torch.Tensor,
+        rel_pos: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
         left_context_key: Optional[torch.Tensor] = None,
         left_context_val: Optional[torch.Tensor] = None,
@@ -556,26 +585,79 @@ class EmformerAttention(nn.Module):
             torch.cat([memory, right_context, utterance])
         ).chunk(chunks=2, dim=2)
 
+        is_streaming_infer = False
         if left_context_key is not None and left_context_val is not None:
             # now compute key and value with
             # [memory, right context, left context, uttrance]
-            # this is used in inference mode
+            # this is used in streaming inference mode
+            is_streaming_infer = True
             key = torch.cat([key[: M + R], left_context_key, key[M + R :]])
             value = torch.cat(
                 [value[: M + R], left_context_val, value[M + R :]]
             )
         Q = query.size(0)
-        # KV = key.size(0)
+        KV = key.size(0)
 
-        reshaped_query, reshaped_key, reshaped_value = [
-            tensor.contiguous()
-            .view(-1, B * self.nhead, self.head_dim)
+        reshaped_key = (
+            key.contiguous()
+            .view(KV, B, self.nhead, self.head_dim)
+            .permute(1, 2, 0, 3)
+        )  # (B, nhead, KV, head_dim)
+        reshaped_value = (
+            value.contiguous()
+            .view(KV, B * self.nhead, self.head_dim)
             .transpose(0, 1)
-            for tensor in [query, key, value]
-        ]  # (B * nhead, Q or KV, head_dim)
-        attention_weights = torch.bmm(
-            reshaped_query * scaling, reshaped_key.transpose(1, 2)
-        )  # (B * nhead, Q, KV)
+        )  # (B * nhead, KV, head_dim)
+        query = (
+            (query * scaling).contiguous().view(Q, B, self.nhead, self.head_dim)
+        )
+        # (B, nhead, Q, head_dim)
+        query_with_bias_u = (query + self._pos_bias_u()).permute(1, 2, 0, 3)
+        query_with_bias_v = (query + self._pos_bias_v()).permute(1, 2, 0, 3)
+
+        PE = pos_emb.size(0)
+        # pos_emb contains flipped positive part and negative part
+        # for relative position i - j between query (i) and key (j)
+        if is_streaming_infer:
+            # i is the first frame in current chunk (query)
+            # j is the last frame in right context (key)
+            # Note: R is equal to self.right_context_length here
+            min_neg_abs = U + R - 1
+            # i is the last frame in right context (query)
+            # j is the first frame in the past context that memory bank can cover (key) # noqa
+            max_pos_abs = U + R + M * self.chunk_length - 1
+        else:
+            # i is the first frame in utterance (query)
+            # j is the last frame in the last chunk's right context (key)
+            min_neg_abs = U + self.right_context_length - 1
+            # i is the last frame in the last chunk's right context (query)
+            # j is the first frame in the utterance (key)
+            max_pos_abs = U + self.right_context_length - 1
+        assert PE == min_neg_abs + max_pos_abs + 1
+        pos_emb = (
+            self.linear_pos(pos_emb)
+            .view(1, PE, self.nhead, self.head_dim)
+            .transpose(1, 2)
+        )  # (1, nhead, PE, head_dim)
+
+        # content-based matrix-ac
+        matrix_ac = torch.matmul(
+            query_with_bias_u, reshaped_key.transpose(-2, -1)
+        )  # (B, nhead, Q, KV)
+
+        # position-based matrix-bd
+        # (B, nhead, Q, PE)
+        matrix_bd = torch.matmul(query_with_bias_v, pos_emb.transpose(-2, -1))
+        # gather position-related scores using pre-computed relative position
+        assert rel_pos.shape == (Q, KV)
+        rel_pos = rel_pos.unsqueeze(0).unsqueeze(1).expand(B, self.nhead, Q, KV)
+        matrix_bd = torch.gather(
+            matrix_bd,
+            dim=-1,
+            index=rel_pos,
+        )  # (B, nhead, Q, KV)
+
+        attention_weights = (matrix_ac + matrix_bd).view(B * self.nhead, Q, KV)
 
         # compute attention probabilities
         attention_probs = self._gen_attention_probs(
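For reference, a self-contained toy illustration (arbitrary sizes, only torch assumed) of the gather step above: matrix_bd first scores every query against every entry on the position-embedding axis (PE), and torch.gather then picks, for each (query, key) pair, the score at relative position rel_pos[q, k].

import torch

B, nhead, Q, KV, PE = 2, 4, 5, 7, 11
matrix_bd = torch.randn(B, nhead, Q, PE)
rel_pos = torch.randint(0, PE, (Q, KV))  # pre-computed, shifted to be >= 0
index = rel_pos.unsqueeze(0).unsqueeze(1).expand(B, nhead, Q, KV)
gathered = torch.gather(matrix_bd, dim=-1, index=index)
assert gathered.shape == (B, nhead, Q, KV)
# spot-check one element against direct indexing
assert gathered[0, 0, 2, 3] == matrix_bd[0, 0, 2, rel_pos[2, 3]]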
@@ -600,6 +682,8 @@ class EmformerAttention(nn.Module):
         right_context: torch.Tensor,
         memory: torch.Tensor,
         attention_mask: torch.Tensor,
+        pos_emb: torch.Tensor,
+        rel_pos: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # TODO: Modify docs.
@@ -647,6 +731,11 @@ class EmformerAttention(nn.Module):
           attention_mask (torch.Tensor):
             Pre-computed attention mask to simulate underlying chunk-wise
             attention, with shape (Q, KV).
+          pos_emb (torch.Tensor):
+            Position embedding, with shape (PE, D),
+            where PE = 2 * (U + right_context_length) - 1.
+          rel_pos (torch.Tensor):
+            Relative positions, with shape (Q, KV).
           padding_mask (torch.Tensor):
             Padding mask of key tensor, with shape (B, KV).
 
@@ -654,10 +743,12 @@ class EmformerAttention(nn.Module):
             Output of right context and utterance, with shape (R + U, B, D).
         """
         output_right_context_utterance, _, _ = self._forward_impl(
-            utterance,
-            right_context,
-            memory,
-            attention_mask,
+            utterance=utterance,
+            right_context=right_context,
+            memory=memory,
+            attention_mask=attention_mask,
+            pos_emb=pos_emb,
+            rel_pos=rel_pos,
             padding_mask=padding_mask,
         )
         return output_right_context_utterance
@ -667,6 +758,8 @@ class EmformerAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
utterance: torch.Tensor,
|
utterance: torch.Tensor,
|
||||||
right_context: torch.Tensor,
|
right_context: torch.Tensor,
|
||||||
|
pos_emb: torch.Tensor,
|
||||||
|
rel_pos: torch.Tensor,
|
||||||
memory: torch.Tensor,
|
memory: torch.Tensor,
|
||||||
left_context_key: torch.Tensor,
|
left_context_key: torch.Tensor,
|
||||||
left_context_val: torch.Tensor,
|
left_context_val: torch.Tensor,
|
||||||
@@ -700,6 +793,11 @@ class EmformerAttention(nn.Module):
           right_context (torch.Tensor):
             Right context frames, with shape (R, B, D),
             where R = right_context_length.
+          pos_emb (torch.Tensor):
+            Position embedding, with shape (PE, D),
+            where PE = 2 * (U + R) + M * chunk_length - 1.
+          rel_pos (torch.Tensor):
+            Relative positions, with shape (Q, KV).
           memory (torch.Tensor):
             Memory vectors, with shape (M, B, D), or empty tensor.
           left_context_key (torch,Tensor):
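A quick arithmetic check (toy values, not from the recipe) of the PE formula in this docstring: in streaming mode the largest negative offset is U + R - 1 and the largest positive offset is U + R + M * chunk_length - 1, so together with offset 0 the table needs 2 * (U + R) + M * chunk_length - 1 rows, matching the assertion in _forward_impl.

chunk_length, U, R, M = 4, 4, 2, 3
min_neg_abs = U + R - 1                     # largest |negative| offset: 5
max_pos_abs = U + R + M * chunk_length - 1  # largest positive offset: 17
PE = 2 * (U + R) + M * chunk_length - 1
assert PE == min_neg_abs + max_pos_abs + 1  # 23 == 5 + 17 + 1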
@@ -733,10 +831,12 @@ class EmformerAttention(nn.Module):
         )
 
         output_right_context_utterance, key, value = self._forward_impl(
-            utterance,
-            right_context,
-            memory,
-            attention_mask,
+            utterance=utterance,
+            right_context=right_context,
+            memory=memory,
+            attention_mask=attention_mask,
+            pos_emb=pos_emb,
+            rel_pos=rel_pos,
             padding_mask=padding_mask,
             left_context_key=left_context_key,
             left_context_val=left_context_val,
@@ -796,6 +896,8 @@ class EmformerEncoderLayer(nn.Module):
         super().__init__()
 
         self.attention = EmformerAttention(
+            chunk_length=chunk_length,
+            right_context_length=right_context_length,
             embed_dim=d_model,
             nhead=nhead,
             dropout=dropout,
@@ -898,6 +1000,8 @@ class EmformerEncoderLayer(nn.Module):
         right_context_utterance: torch.Tensor,
         R: int,
         attention_mask: torch.Tensor,
+        pos_emb: torch.Tensor,
+        rel_pos: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Apply attention module in training and validation mode."""
@@ -917,15 +1021,18 @@ class EmformerEncoderLayer(nn.Module):
             right_context=right_context,
             memory=memory,
             attention_mask=attention_mask,
+            pos_emb=pos_emb,
+            rel_pos=rel_pos,
             padding_mask=padding_mask,
         )
-
         return output_right_context_utterance
 
     def _apply_attention_module_infer(
         self,
         right_context_utterance: torch.Tensor,
         R: int,
+        pos_emb: torch.Tensor,
+        rel_pos: torch.Tensor,
         attn_cache: List[torch.Tensor],
         padding_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
@@ -962,6 +1069,8 @@ class EmformerEncoderLayer(nn.Module):
         ) = self.attention.infer(
             utterance=utterance,
             right_context=right_context,
+            pos_emb=pos_emb,
+            rel_pos=rel_pos,
             memory=pre_memory,
             left_context_key=left_context_key,
             left_context_val=left_context_val,
@@ -977,6 +1086,8 @@ class EmformerEncoderLayer(nn.Module):
         utterance: torch.Tensor,
         right_context: torch.Tensor,
         attention_mask: torch.Tensor,
+        pos_emb: torch.Tensor,
+        rel_pos: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
         warmup: float = 1.0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -996,6 +1107,11 @@ class EmformerEncoderLayer(nn.Module):
           attention_mask (torch.Tensor):
             Attention mask for underlying attention module,
             with shape (Q, KV), where Q = R + U, KV = M + R + U.
+          pos_emb (torch.Tensor):
+            Position embedding, with shape (PE, D),
+            where PE = 2 * (U + right_context_length) - 1.
+          rel_pos (torch.Tensor):
+            Relative positions, with shape (Q, KV).
           padding_mask (torch.Tensor):
             Padding mask of ker tensor, with shape (B, KV).
 
@@ -1025,7 +1141,12 @@ class EmformerEncoderLayer(nn.Module):
 
         # emformer attention module
         src_att = self._apply_attention_module_forward(
-            src, R, attention_mask, padding_mask=padding_mask
+            right_context_utterance=src,
+            R=R,
+            attention_mask=attention_mask,
+            pos_emb=pos_emb,
+            rel_pos=rel_pos,
+            padding_mask=padding_mask,
         )
         src = src + self.dropout(src_att)
 
@@ -1050,6 +1171,8 @@ class EmformerEncoderLayer(nn.Module):
         self,
         utterance: torch.Tensor,
         right_context: torch.Tensor,
+        pos_emb: torch.Tensor,
+        rel_pos: torch.Tensor,
         attn_cache: List[torch.Tensor],
         conv_cache: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
@@ -1067,6 +1190,12 @@ class EmformerEncoderLayer(nn.Module):
             Utterance frames, with shape (U, B, D).
           right_context (torch.Tensor):
             Right context frames, with shape (R, B, D).
+          pos_emb (torch.Tensor):
+            Position embedding, with shape (PE, D),
+            where PE = 2 * (U + R) + M * chunk_length - 1.
+          rel_pos (torch.Tensor):
+            Relative positions, with shape (Q, KV),
+            where Q = R + U, KV = M + R + L + U.
           attn_cache (List[torch.Tensor]):
             Cached attention tensors generated in preceding computation,
             including memory, key and value of left context.
@@ -1090,7 +1219,12 @@ class EmformerEncoderLayer(nn.Module):
 
         # emformer attention module
         src_att, attn_cache = self._apply_attention_module_infer(
-            src, R, attn_cache, padding_mask=padding_mask
+            right_context_utterance=src,
+            R=R,
+            pos_emb=pos_emb,
+            rel_pos=rel_pos,
+            attn_cache=attn_cache,
+            padding_mask=padding_mask,
        )
         src = src + self.dropout(src_att)
 
@@ -1187,6 +1321,7 @@ class EmformerEncoder(nn.Module):
 
         self.use_memory = memory_size > 0
 
+        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
         self.emformer_layers = nn.ModuleList(
             [
                 EmformerEncoderLayer(
@@ -1215,7 +1350,9 @@ class EmformerEncoder(nn.Module):
         self.memory_size = memory_size
         self.cnn_module_kernel = cnn_module_kernel
 
-    def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
+    def _gen_right_context(
+        self, x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Hard copy each chunk's right context and concat them."""
         T = x.shape[0]
         num_chunks = math.ceil(
@@ -1235,9 +1372,9 @@ class EmformerEncoder(nn.Module):
                 indexes,
                 torch.arange(T - self.right_context_length, T).unsqueeze(0),
             ]
-        )
-        right_context_blocks = x[indexes.reshape(-1)]
-        return right_context_blocks
+        ).reshape(-1)
+        right_context_blocks = x[indexes]
+        return right_context_blocks, indexes
 
     def _gen_attention_mask_col_widths(
         self, chunk_idx: int, U: int
@@ -1381,10 +1518,33 @@ class EmformerEncoder(nn.Module):
           - output_lengths, with shape (B,), without containing the
             right_context at the end.
         """
-        U = x.size(0) - self.right_context_length
+        x, pos_emb = self.encoder_pos(x, pos_len=x.size(0), neg_len=x.size(0))
 
-        right_context = self._gen_right_context(x)
+        U = x.size(0) - self.right_context_length
+        right_context, right_context_indexes = self._gen_right_context(x)
+        utterance_indexes = torch.arange(0, U)
         utterance = x[:U]
+        num_chunks = math.ceil(U / self.chunk_length)
+        memory_indexes = (
+            torch.arange(
+                self.chunk_length // 2,
+                (num_chunks - 1) * self.chunk_length,
+                self.chunk_length,
+            )
+            if num_chunks > 1
+            else torch.empty(0).to(dtype=utterance_indexes.dtype)
+        )
+        query_indexes = torch.cat(
+            [right_context_indexes, utterance_indexes]
+        ).to(device=x.device)
+        key_indexes = torch.cat(
+            [memory_indexes, right_context_indexes, utterance_indexes]
+        ).to(device=x.device)
+        # calculate relative position and flip sign
+        rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
+        # shift to start from zero
+        rel_pos = rel_pos - rel_pos.min()
+
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)
 
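A toy walk-through (illustrative indexes, only torch assumed) of the rel_pos table built above: absolute frame indexes of queries and keys are broadcast-subtracted, the sign of i - j is flipped, and the result is shifted to start at zero so it can serve directly as a gather index into the pos_emb layout (flipped positive part first, then negative part).

import torch

query_indexes = torch.tensor([8, 9, 0, 1, 2, 3])   # e.g. right context, then utterance
key_indexes = torch.tensor([2, 8, 9, 0, 1, 2, 3])  # e.g. memory, right context, utterance
rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))  # (Q, KV)
rel_pos = rel_pos - rel_pos.min()                  # shift to start from zero
assert rel_pos.shape == (6, 7) and rel_pos.min() == 0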
@@ -1394,9 +1554,11 @@ class EmformerEncoder(nn.Module):
         output = utterance
         for layer in self.emformer_layers:
             output, right_context = layer(
-                output,
-                right_context,
-                attention_mask,
+                utterance=output,
+                right_context=right_context,
+                attention_mask=attention_mask,
+                pos_emb=pos_emb,
+                rel_pos=rel_pos,
                 padding_mask=padding_mask,
                 warmup=warmup,
             )
@@ -1445,6 +1607,7 @@ class EmformerEncoder(nn.Module):
         """
         assert num_processed_frames.shape == (x.size(1),)
 
+        # check the shapes of states
         attn_caches = states[0]
         assert len(attn_caches) == self.num_encoder_layers, len(attn_caches)
         for i in range(len(attn_caches)):
@@ -1473,6 +1636,11 @@ class EmformerEncoder(nn.Module):
                 self.cnn_module_kernel - 1,
             ), conv_caches[i].shape
 
+        tot_past_length = self.memory_size * self.chunk_length
+        x, pos_emb = self.encoder_pos(
+            x, pos_len=x.size(0) + tot_past_length, neg_len=x.size(0)
+        )
+
         right_context = x[-self.right_context_length :]
         utterance = x[: -self.right_context_length]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
@@ -1504,6 +1672,36 @@ class EmformerEncoder(nn.Module):
             dim=1,
         )
 
+        # calculate relative position
+        memory_indexes = torch.arange(
+            self.chunk_length // 2, tot_past_length, self.chunk_length
+        )
+        left_context_indexes = torch.arange(
+            tot_past_length - self.left_context_length, tot_past_length
+        )
+        utterance_indexes = torch.arange(
+            tot_past_length, tot_past_length + utterance.size(0)
+        )
+        right_context_indexes = torch.arange(
+            tot_past_length + utterance.size(0),
+            tot_past_length + utterance.size(0) + right_context.size(0),
+        )
+        query_indexes = torch.cat(
+            [right_context_indexes, utterance_indexes]
+        ).to(device=x.device)
+        key_indexes = torch.cat(
+            [
+                memory_indexes,
+                right_context_indexes,
+                left_context_indexes,
+                utterance_indexes,
+            ]
+        ).to(device=x.device)
+        # calculate relative position and flip sign
+        rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
+        # shift to start from zero
+        rel_pos = rel_pos - rel_pos.min()
+
         output = utterance
         output_attn_caches: List[List[torch.Tensor]] = []
         output_conv_caches: List[torch.Tensor] = []
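A toy check (illustrative values, not from the recipe) that the streaming key index order above lines up with the concatenation order used in EmformerAttention._forward_impl, i.e. [memory, right context, left context, utterance]:

import torch

chunk_length, memory_size, left_context_length = 4, 2, 3
tot_past_length = memory_size * chunk_length  # 8
U, R = 4, 2
memory_indexes = torch.arange(chunk_length // 2, tot_past_length, chunk_length)
left_context_indexes = torch.arange(tot_past_length - left_context_length, tot_past_length)
utterance_indexes = torch.arange(tot_past_length, tot_past_length + U)
right_context_indexes = torch.arange(tot_past_length + U, tot_past_length + U + R)
key_indexes = torch.cat(
    [memory_indexes, right_context_indexes, left_context_indexes, utterance_indexes]
)
print(key_indexes)  # tensor([ 2,  6, 12, 13,  5,  6,  7,  8,  9, 10, 11])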
@@ -1514,8 +1712,10 @@ class EmformerEncoder(nn.Module):
                 output_attn_cache,
                 output_conv_cache,
             ) = layer.infer(
-                output,
-                right_context,
+                utterance=output,
+                right_context=right_context,
+                pos_emb=pos_emb,
+                rel_pos=rel_pos,
                 padding_mask=padding_mask,
                 attn_cache=attn_caches[layer_idx],
                 conv_cache=conv_caches[layer_idx],
@@ -1597,6 +1797,10 @@ class Emformer(EncoderInterface):
             raise NotImplementedError(
                 "right_context_length must be 0 or a mutiple of subsampling_factor."  # noqa
             )
+        if memory_size > 0 and memory_size * chunk_length < left_context_length:
+            raise NotImplementedError(
+                "memory_size * chunk_length must not be smaller than left_context_length."  # noqa
+            )
 
         # self.encoder_embed converts the input of shape (N, T, num_features)
         # to the shape (N, T//subsampling_factor, d_model).
@@ -1822,3 +2026,119 @@ class Conv2dSubsampling(nn.Module):
         x = self.out_norm(x)
         x = self.out_balancer(x)
         return x
+
+
+class RelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding module.
+
+    See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"  # noqa
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py  # noqa
+
+    Suppose:
+    i -> position of query,
+    j -> position of key(value),
+    we use positive relative position embedding when key(value) is to the
+    left of query(i.e., i > j) and negative embedding otherwise.
+
+    Args:
+      d_model: Embedding dimension.
+      dropout: Dropout rate.
+      max_len: Maximum input length.
+    """
+
+    def __init__(
+        self, d_model: int, dropout: float, max_len: int = 5000
+    ) -> None:
+        """Construct an PositionalEncoding object."""
+        super(RelPositionalEncoding, self).__init__()
+        self.d_model = d_model
+        self.dropout = torch.nn.Dropout(p=dropout)
+        self.pe = None
+        self.pos_len = max_len
+        self.neg_len = max_len
+        self.gen_pe_positive()
+        self.gen_pe_negative()
+
+    def gen_pe_positive(self) -> None:
+        """Generate the positive positional encodings."""
+        pe_positive = torch.zeros(self.pos_len, self.d_model)
+        position_positive = torch.arange(
+            0, self.pos_len, dtype=torch.float32
+        ).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position_positive * div_term)
+        pe_positive[:, 1::2] = torch.cos(position_positive * div_term)
+        # Reserve the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"  # noqa
+        self.pe_positive = torch.flip(pe_positive, [0])
+
+    def gen_pe_negative(self) -> None:
+        """Generate the negative positional encodings."""
+        # Suppose `i` means to the position of query vecotr and `j` means the
+        # position of key vector. We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_negative = torch.zeros(self.neg_len, self.d_model)
+        position_negative = torch.arange(
+            0, self.neg_len, dtype=torch.float32
+        ).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_negative[:, 0::2] = torch.sin(-1 * position_negative * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position_negative * div_term)
+        self.pe_negative = pe_negative
+
+    def get_pe(
+        self,
+        pos_len: int,
+        neg_len: int,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> torch.Tensor:
+        """Get positional encoding given positive length and negative length."""
+        if self.pe_positive.dtype != dtype or str(
+            self.pe_positive.device
+        ) != str(device):
+            self.pe_positive = self.pe_positive.to(dtype=dtype, device=device)
+        if self.pe_negative.dtype != dtype or str(
+            self.pe_negative.device
+        ) != str(device):
+            self.pe_negative = self.pe_negative.to(dtype=dtype, device=device)
+        pe = torch.cat(
+            [
+                self.pe_positive[self.pos_len - pos_len :],
+                self.pe_negative[1:neg_len],
+            ],
+            dim=0,
+        )
+        return pe
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        pos_len: int,
+        neg_len: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Scale input x and get positional encoding.
+
+        Args:
+          x (torch.Tensor): Input tensor (`*`).
+
+        Returns:
+          torch.Tensor:
+            Encoded tensor of shape (`*`).
+          torch.Tensor:
+            Position embedding of shape (pos_len + neg_len - 1, `*`).
+        """
+        if pos_len > self.pos_len:
+            self.pos_len = pos_len
+            self.gen_pe_positive()
+        if neg_len > self.neg_len:
+            self.neg_len = neg_len
+            self.gen_pe_negative()
+        pos_emb = self.get_pe(pos_len, neg_len, x.device, x.dtype)
+        return self.dropout(x), self.dropout(pos_emb)
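A minimal usage sketch of the RelPositionalEncoding class added above (the import path is hypothetical; the shape check follows the docstring): with pos_len = neg_len = L, get_pe concatenates L flipped positive rows with L - 1 negative rows, covering relative offsets L - 1 down to -(L - 1).

import torch
from emformer import RelPositionalEncoding  # hypothetical import path

d_model, L = 16, 10
encoder_pos = RelPositionalEncoding(d_model, dropout=0.0)
x = torch.randn(L, 2, d_model)
x_out, pos_emb = encoder_pos(x, pos_len=L, neg_len=L)
assert pos_emb.shape == (2 * L - 1, d_model)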
@@ -114,8 +114,12 @@ def test_state_stack_unstack():
         for _ in range(num_encoder_layers)
     ]
     states = [attn_caches, conv_caches]
-    x = torch.randn(batch_size, 23, num_features)
-    x_lens = torch.full((batch_size,), 23)
+    x = torch.randn(
+        batch_size, chunk_length + right_context_length + 3, num_features
+    )
+    x_lens = torch.full(
+        (batch_size,), chunk_length + right_context_length + 3
+    )
     num_processed_frames = torch.full((batch_size,), 0)
     y, y_lens, states = model.infer(
         x, x_lens, num_processed_frames=num_processed_frames, states=states
@@ -172,8 +176,10 @@ def test_torchscript_consistency_infer():
         for _ in range(num_encoder_layers)
     ]
     states = [attn_caches, conv_caches]
-    x = torch.randn(batch_size, 23, num_features)
-    x_lens = torch.full((batch_size,), 23)
+    x = torch.randn(
+        batch_size, chunk_length + right_context_length + 3, num_features
+    )
+    x_lens = torch.full((batch_size,), chunk_length + right_context_length + 3)
     num_processed_frames = torch.full((batch_size,), 0)
     y, y_lens, out_states = model.infer(
         x, x_lens, num_processed_frames=num_processed_frames, states=states
@@ -187,8 +193,38 @@ def test_torchscript_consistency_infer():
     assert torch.allclose(y, sc_y)
 
 
+def test_emformer_forward_shape():
+    num_features = 80
+    chunk_length = 32
+    encoder_dim = 512
+    num_encoder_layers = 2
+    kernel_size = 31
+    left_context_length = 32
+    right_context_length = 8
+    memory_size = 32
+    batch_size = 2
+
+    model = Emformer(
+        num_features=num_features,
+        chunk_length=chunk_length,
+        subsampling_factor=4,
+        d_model=encoder_dim,
+        num_encoder_layers=num_encoder_layers,
+        cnn_module_kernel=kernel_size,
+        left_context_length=left_context_length,
+        right_context_length=right_context_length,
+        memory_size=memory_size,
+    )
+    U = 2 * chunk_length
+    x = torch.randn(batch_size, U + right_context_length + 3, num_features)
+    x_lens = torch.full((batch_size,), U + right_context_length + 3)
+    output, output_lengths = model(x, x_lens)
+    assert output.shape == (batch_size, U >> 2, encoder_dim)
+
+
 if __name__ == "__main__":
     test_convolution_module_forward()
     test_convolution_module_infer()
     test_state_stack_unstack()
     test_torchscript_consistency_infer()
+    test_emformer_forward_shape()