support position encoding

yaozengwei 2022-06-26 21:39:06 +08:00
parent 4929faeed4
commit 630626a092
3 changed files with 394 additions and 39 deletions

View File

@ -9,8 +9,7 @@ per-file-ignores =
egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
egs/*/ASR/*/optim.py: E501,
egs/*/ASR/*/scaling.py: E501,
egs/librispeech/ASR/conv_emformer_transducer_stateless*/emformer.py: E501, E203
# invalid escape sequence (caused by TeX formula), W605
icefall/utils.py: E501, W605

View File

@ -35,7 +35,6 @@ from scaling import (
from icefall.utils import make_pad_mask
LOG_EPSILON = math.log(1e-10)
@ -434,6 +433,10 @@ class EmformerAttention(nn.Module):
r"""Emformer layer attention module. r"""Emformer layer attention module.
Args: Args:
chunk_length (int):
Length of chunk.
right_context_length (int):
Length of right context.
embed_dim (int):
Embedding dimension.
nhead (int):
@ -448,6 +451,8 @@ class EmformerAttention(nn.Module):
def __init__(
self,
chunk_length: int,
right_context_length: int,
embed_dim: int,
nhead: int,
dropout: float = 0.0,
@ -455,6 +460,8 @@ class EmformerAttention(nn.Module):
negative_inf: float = -1e8,
):
super().__init__()
self.chunk_length = chunk_length
self.right_context_length = right_context_length
if embed_dim % nhead != 0:
raise ValueError(
@ -477,6 +484,26 @@ class EmformerAttention(nn.Module):
embed_dim, embed_dim, bias=True, initial_scale=0.25
)
# linear transformation for positional encoding.
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
# these two learnable biases are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # noqa
self.pos_bias_u = nn.Parameter(torch.Tensor(nhead, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(nhead, self.head_dim))
self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
self._reset_parameters()
def _pos_bias_u(self):
return self.pos_bias_u * self.pos_bias_u_scale.exp()
def _pos_bias_v(self):
return self.pos_bias_v * self.pos_bias_v_scale.exp()
def _reset_parameters(self) -> None:
nn.init.normal_(self.pos_bias_u, std=0.01)
nn.init.normal_(self.pos_bias_v, std=0.01)
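The two biases above play the role of the learnable vectors u and v from Transformer-XL Section 3.3 (the "c" and "d" terms of the attention score). They are stored together with log-scales, following the same scale-reparameterized pattern as the ScaledLinear modules from scaling.py used elsewhere in this recipe. A minimal standalone sketch of the effective-bias computation, with made-up dimensions; it is illustrative only and not part of the diff:

import torch
import torch.nn as nn

nhead, head_dim = 4, 16                                  # toy sizes
pos_bias_u = nn.Parameter(torch.empty(nhead, head_dim))
pos_bias_u_scale = nn.Parameter(torch.zeros(()))         # log-scale, starts at 0 => scale 1
nn.init.normal_(pos_bias_u, std=0.01)

# effective bias = bias * exp(log_scale); this is what _pos_bias_u() returns
effective_bias_u = pos_bias_u * pos_bias_u_scale.exp()
print(effective_bias_u.shape)                            # torch.Size([4, 16])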
def _gen_attention_probs(
self,
attention_weights: torch.Tensor,
@ -539,6 +566,8 @@ class EmformerAttention(nn.Module):
right_context: torch.Tensor,
memory: torch.Tensor,
attention_mask: torch.Tensor,
pos_emb: torch.Tensor,
rel_pos: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
left_context_key: Optional[torch.Tensor] = None,
left_context_val: Optional[torch.Tensor] = None,
@ -556,26 +585,79 @@ class EmformerAttention(nn.Module):
torch.cat([memory, right_context, utterance])
).chunk(chunks=2, dim=2)
is_streaming_infer = False
if left_context_key is not None and left_context_val is not None:
# now compute key and value with
# [memory, right context, left context, utterance]
# this is used in streaming inference mode
is_streaming_infer = True
key = torch.cat([key[: M + R], left_context_key, key[M + R :]])
value = torch.cat(
[value[: M + R], left_context_val, value[M + R :]]
)
Q = query.size(0)
KV = key.size(0)
reshaped_key = (
key.contiguous()
.view(KV, B, self.nhead, self.head_dim)
.permute(1, 2, 0, 3)
)  # (B, nhead, KV, head_dim)
reshaped_value = (
value.contiguous()
.view(KV, B * self.nhead, self.head_dim)
.transpose(0, 1)
)  # (B * nhead, KV, head_dim)
query = (
(query * scaling).contiguous().view(Q, B, self.nhead, self.head_dim)
)
# (B, nhead, Q, head_dim)
query_with_bias_u = (query + self._pos_bias_u()).permute(1, 2, 0, 3)
query_with_bias_v = (query + self._pos_bias_v()).permute(1, 2, 0, 3)
PE = pos_emb.size(0)
# pos_emb contains flipped positive part and negative part
# for relative position i - j between query (i) and key (j)
if is_streaming_infer:
# i is the first frame in current chunk (query)
# j is the last frame in right context (key)
# Note: R is equal to self.right_context_length here
min_neg_abs = U + R - 1
# i is the last frame in right context (query)
# j is the first frame in the past context that memory bank can cover (key) # noqa
max_pos_abs = U + R + M * self.chunk_length - 1
else:
# i is the first frame in utterance (query)
# j is the last frame in the last chunk's right context (key)
min_neg_abs = U + self.right_context_length - 1
# i is the last frame in the last chunk's right context (query)
# j is the first frame in the utterance (key)
max_pos_abs = U + self.right_context_length - 1
assert PE == min_neg_abs + max_pos_abs + 1
pos_emb = (
self.linear_pos(pos_emb)
.view(1, PE, self.nhead, self.head_dim)
.transpose(1, 2)
) # (1, nhead, PE, head_dim)
# content-based matrix-ac
matrix_ac = torch.matmul(
query_with_bias_u, reshaped_key.transpose(-2, -1)
) # (B, nhead, Q, KV)
# position-based matrix-bd
# (B, nhead, Q, PE)
matrix_bd = torch.matmul(query_with_bias_v, pos_emb.transpose(-2, -1))
# gather position-related scores using pre-computed relative position
assert rel_pos.shape == (Q, KV)
rel_pos = rel_pos.unsqueeze(0).unsqueeze(1).expand(B, self.nhead, Q, KV)
matrix_bd = torch.gather(
matrix_bd,
dim=-1,
index=rel_pos,
) # (B, nhead, Q, KV)
attention_weights = (matrix_ac + matrix_bd).view(B * self.nhead, Q, KV)
# compute attention probabilities
attention_probs = self._gen_attention_probs(
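The relative-position handling above has two pieces: pos_emb holds one embedding per possible relative offset (its length PE is checked against min_neg_abs + max_pos_abs + 1), and rel_pos maps every (query, key) pair to the row of pos_emb holding its offset, so matrix_bd can be rearranged from (B, nhead, Q, PE) to (B, nhead, Q, KV) with a single torch.gather. A standalone sketch with toy sizes; U, R, M, chunk_length and all tensors below are made up and not taken from the commit:

import torch

# expected pos_emb length for toy sizes
U, R, M, chunk_length = 8, 2, 3, 4
min_neg_abs = U + R - 1                                   # 9
max_pos_abs_train = U + R - 1                             # 9
max_pos_abs_infer = U + R + M * chunk_length - 1          # 21
assert min_neg_abs + max_pos_abs_train + 1 == 2 * (U + R) - 1                     # 19
assert min_neg_abs + max_pos_abs_infer + 1 == 2 * (U + R) + M * chunk_length - 1  # 31

# gather-based rearrangement of the position term
B, nhead, head_dim = 2, 4, 8
Q, KV, PE = 6, 10, 15
query_with_bias_v = torch.randn(B, nhead, Q, head_dim)
pos_emb = torch.randn(1, nhead, PE, head_dim)
rel_pos = torch.randint(0, PE, (Q, KV))                   # real code: shifted j - i offsets
matrix_bd = torch.matmul(query_with_bias_v, pos_emb.transpose(-2, -1))  # (B, nhead, Q, PE)
index = rel_pos.unsqueeze(0).unsqueeze(1).expand(B, nhead, Q, KV)
matrix_bd = torch.gather(matrix_bd, dim=-1, index=index)                # (B, nhead, Q, KV)
print(matrix_bd.shape)                                    # torch.Size([2, 4, 6, 10])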
@ -600,6 +682,8 @@ class EmformerAttention(nn.Module):
right_context: torch.Tensor,
memory: torch.Tensor,
attention_mask: torch.Tensor,
pos_emb: torch.Tensor,
rel_pos: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# TODO: Modify docs.
@ -647,6 +731,11 @@ class EmformerAttention(nn.Module):
attention_mask (torch.Tensor):
Pre-computed attention mask to simulate underlying chunk-wise
attention, with shape (Q, KV).
pos_emb (torch.Tensor):
Position embedding, with shape (PE, D),
where PE = 2 * (U + right_context_length) - 1.
rel_pos (torch.Tensor):
Relative positions, with shape (Q, KV).
padding_mask (torch.Tensor):
Padding mask of key tensor, with shape (B, KV).
@ -654,10 +743,12 @@ class EmformerAttention(nn.Module):
Output of right context and utterance, with shape (R + U, B, D).
"""
output_right_context_utterance, _, _ = self._forward_impl(
utterance=utterance,
right_context=right_context,
memory=memory,
attention_mask=attention_mask,
pos_emb=pos_emb,
rel_pos=rel_pos,
padding_mask=padding_mask,
)
return output_right_context_utterance
@ -667,6 +758,8 @@ class EmformerAttention(nn.Module):
self,
utterance: torch.Tensor,
right_context: torch.Tensor,
pos_emb: torch.Tensor,
rel_pos: torch.Tensor,
memory: torch.Tensor,
left_context_key: torch.Tensor,
left_context_val: torch.Tensor,
@ -700,6 +793,11 @@ class EmformerAttention(nn.Module):
right_context (torch.Tensor):
Right context frames, with shape (R, B, D),
where R = right_context_length.
pos_emb (torch.Tensor):
Position embedding, with shape (PE, D),
where PE = 2 * (U + R) + M * chunk_length - 1.
rel_pos (torch.Tensor):
Relative positions, with shape (Q, KV).
memory (torch.Tensor):
Memory vectors, with shape (M, B, D), or empty tensor.
left_context_key (torch.Tensor):
@ -733,10 +831,12 @@ class EmformerAttention(nn.Module):
)
output_right_context_utterance, key, value = self._forward_impl(
utterance=utterance,
right_context=right_context,
memory=memory,
attention_mask=attention_mask,
pos_emb=pos_emb,
rel_pos=rel_pos,
padding_mask=padding_mask,
left_context_key=left_context_key,
left_context_val=left_context_val,
@ -796,6 +896,8 @@ class EmformerEncoderLayer(nn.Module):
super().__init__()
self.attention = EmformerAttention(
chunk_length=chunk_length,
right_context_length=right_context_length,
embed_dim=d_model,
nhead=nhead,
dropout=dropout,
@ -898,6 +1000,8 @@ class EmformerEncoderLayer(nn.Module):
right_context_utterance: torch.Tensor,
R: int,
attention_mask: torch.Tensor,
pos_emb: torch.Tensor,
rel_pos: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Apply attention module in training and validation mode."""
@ -917,15 +1021,18 @@ class EmformerEncoderLayer(nn.Module):
right_context=right_context,
memory=memory,
attention_mask=attention_mask,
pos_emb=pos_emb,
rel_pos=rel_pos,
padding_mask=padding_mask,
)
return output_right_context_utterance
def _apply_attention_module_infer(
self,
right_context_utterance: torch.Tensor,
R: int,
pos_emb: torch.Tensor,
rel_pos: torch.Tensor,
attn_cache: List[torch.Tensor],
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
@ -962,6 +1069,8 @@ class EmformerEncoderLayer(nn.Module):
) = self.attention.infer(
utterance=utterance,
right_context=right_context,
pos_emb=pos_emb,
rel_pos=rel_pos,
memory=pre_memory,
left_context_key=left_context_key,
left_context_val=left_context_val,
@ -977,6 +1086,8 @@ class EmformerEncoderLayer(nn.Module):
utterance: torch.Tensor,
right_context: torch.Tensor,
attention_mask: torch.Tensor,
pos_emb: torch.Tensor,
rel_pos: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
@ -996,6 +1107,11 @@ class EmformerEncoderLayer(nn.Module):
attention_mask (torch.Tensor):
Attention mask for underlying attention module,
with shape (Q, KV), where Q = R + U, KV = M + R + U.
pos_emb (torch.Tensor):
Position embedding, with shape (PE, D),
where PE = 2 * (U + right_context_length) - 1.
rel_pos (torch.Tensor):
Relative positions, with shape (Q, KV).
padding_mask (torch.Tensor):
Padding mask of key tensor, with shape (B, KV).
@ -1025,7 +1141,12 @@ class EmformerEncoderLayer(nn.Module):
# emformer attention module
src_att = self._apply_attention_module_forward(
right_context_utterance=src,
R=R,
attention_mask=attention_mask,
pos_emb=pos_emb,
rel_pos=rel_pos,
padding_mask=padding_mask,
)
src = src + self.dropout(src_att)
@ -1050,6 +1171,8 @@ class EmformerEncoderLayer(nn.Module):
self,
utterance: torch.Tensor,
right_context: torch.Tensor,
pos_emb: torch.Tensor,
rel_pos: torch.Tensor,
attn_cache: List[torch.Tensor],
conv_cache: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
@ -1067,6 +1190,12 @@ class EmformerEncoderLayer(nn.Module):
Utterance frames, with shape (U, B, D).
right_context (torch.Tensor):
Right context frames, with shape (R, B, D).
pos_emb (torch.Tensor):
Position embedding, with shape (PE, D),
where PE = 2 * (U + R) + M * chunk_length - 1.
rel_pos (torch.Tensor):
Relative positions, with shape (Q, KV),
where Q = R + U, KV = M + R + L + U.
attn_cache (List[torch.Tensor]):
Cached attention tensors generated in preceding computation,
including memory, key and value of left context.
@ -1090,7 +1219,12 @@ class EmformerEncoderLayer(nn.Module):
# emformer attention module
src_att, attn_cache = self._apply_attention_module_infer(
right_context_utterance=src,
R=R,
pos_emb=pos_emb,
rel_pos=rel_pos,
attn_cache=attn_cache,
padding_mask=padding_mask,
)
src = src + self.dropout(src_att)
@ -1187,6 +1321,7 @@ class EmformerEncoder(nn.Module):
self.use_memory = memory_size > 0
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
self.emformer_layers = nn.ModuleList(
[
EmformerEncoderLayer(
@ -1215,7 +1350,9 @@ class EmformerEncoder(nn.Module):
self.memory_size = memory_size
self.cnn_module_kernel = cnn_module_kernel
def _gen_right_context(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Hard copy each chunk's right context and concat them."""
T = x.shape[0]
num_chunks = math.ceil(
@ -1235,9 +1372,9 @@ class EmformerEncoder(nn.Module):
indexes,
torch.arange(T - self.right_context_length, T).unsqueeze(0),
]
).reshape(-1)
right_context_blocks = x[indexes]
return right_context_blocks, indexes
def _gen_attention_mask_col_widths(
self, chunk_idx: int, U: int
@ -1381,10 +1518,33 @@ class EmformerEncoder(nn.Module):
- output_lengths, with shape (B,), without containing the
right_context at the end.
"""
x, pos_emb = self.encoder_pos(x, pos_len=x.size(0), neg_len=x.size(0))
U = x.size(0) - self.right_context_length
right_context, right_context_indexes = self._gen_right_context(x)
utterance_indexes = torch.arange(0, U)
utterance = x[:U]
num_chunks = math.ceil(U / self.chunk_length)
memory_indexes = (
torch.arange(
self.chunk_length // 2,
(num_chunks - 1) * self.chunk_length,
self.chunk_length,
)
if num_chunks > 1
else torch.empty(0).to(dtype=utterance_indexes.dtype)
)
query_indexes = torch.cat(
[right_context_indexes, utterance_indexes]
).to(device=x.device)
key_indexes = torch.cat(
[memory_indexes, right_context_indexes, utterance_indexes]
).to(device=x.device)
# calculate relative position and flip sign
rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
# shift to start from zero
rel_pos = rel_pos - rel_pos.min()
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
attention_mask = self._gen_attention_mask(utterance)
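For the non-streaming forward pass, rel_pos is built from absolute frame indices of the query (hard-copied right contexts followed by the utterance) and the key (memory-bank centers, right contexts, utterance); the pairwise differences are negated and shifted so the smallest value is 0 and can index into pos_emb. A toy sketch follows; the sizes are made up and the right context is simplified to a single trailing block rather than the per-chunk copies produced by _gen_right_context:

import torch

U, right_context_length, chunk_length = 8, 2, 4
num_chunks = -(-U // chunk_length)                        # ceil(U / chunk_length) = 2

utterance_indexes = torch.arange(0, U)
right_context_indexes = torch.arange(U, U + right_context_length)
memory_indexes = torch.arange(
    chunk_length // 2, (num_chunks - 1) * chunk_length, chunk_length
)                                                         # one center per summarized chunk

query_indexes = torch.cat([right_context_indexes, utterance_indexes])
key_indexes = torch.cat([memory_indexes, right_context_indexes, utterance_indexes])
rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
rel_pos = rel_pos - rel_pos.min()                         # shift so indices start at 0
print(rel_pos.shape)                                      # (R + U, M + R + U) = (10, 11)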
@ -1394,9 +1554,11 @@ class EmformerEncoder(nn.Module):
output = utterance
for layer in self.emformer_layers:
output, right_context = layer(
utterance=output,
right_context=right_context,
attention_mask=attention_mask,
pos_emb=pos_emb,
rel_pos=rel_pos,
padding_mask=padding_mask,
warmup=warmup,
)
@ -1445,6 +1607,7 @@ class EmformerEncoder(nn.Module):
""" """
assert num_processed_frames.shape == (x.size(1),) assert num_processed_frames.shape == (x.size(1),)
# check the shapes of states
attn_caches = states[0]
assert len(attn_caches) == self.num_encoder_layers, len(attn_caches)
for i in range(len(attn_caches)):
@ -1473,6 +1636,11 @@ class EmformerEncoder(nn.Module):
self.cnn_module_kernel - 1,
), conv_caches[i].shape
tot_past_length = self.memory_size * self.chunk_length
x, pos_emb = self.encoder_pos(
x, pos_len=x.size(0) + tot_past_length, neg_len=x.size(0)
)
right_context = x[-self.right_context_length :]
utterance = x[: -self.right_context_length]
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
@ -1504,6 +1672,36 @@ class EmformerEncoder(nn.Module):
dim=1,
)
# calculate relative position
memory_indexes = torch.arange(
self.chunk_length // 2, tot_past_length, self.chunk_length
)
left_context_indexes = torch.arange(
tot_past_length - self.left_context_length, tot_past_length
)
utterance_indexes = torch.arange(
tot_past_length, tot_past_length + utterance.size(0)
)
right_context_indexes = torch.arange(
tot_past_length + utterance.size(0),
tot_past_length + utterance.size(0) + right_context.size(0),
)
query_indexes = torch.cat(
[right_context_indexes, utterance_indexes]
).to(device=x.device)
key_indexes = torch.cat(
[
memory_indexes,
right_context_indexes,
left_context_indexes,
utterance_indexes,
]
).to(device=x.device)
# calculate relative position and flip sign
rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
# shift to start from zero
rel_pos = rel_pos - rel_pos.min()
output = utterance
output_attn_caches: List[List[torch.Tensor]] = []
output_conv_caches: List[torch.Tensor] = []
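In streaming inference the same construction is used, but every cached frame (memory bank and left context) is assigned an absolute position inside [0, tot_past_length), with the current utterance and right context placed after it. A toy sketch with made-up sizes, not taken from the commit:

import torch

memory_size, chunk_length, left_context_length = 3, 4, 4
U, R = 4, 2
tot_past_length = memory_size * chunk_length                                      # 12

memory_indexes = torch.arange(chunk_length // 2, tot_past_length, chunk_length)   # [2, 6, 10]
left_context_indexes = torch.arange(
    tot_past_length - left_context_length, tot_past_length
)                                                                                  # [8, 9, 10, 11]
utterance_indexes = torch.arange(tot_past_length, tot_past_length + U)             # [12..15]
right_context_indexes = torch.arange(
    tot_past_length + U, tot_past_length + U + R
)                                                                                  # [16, 17]

query_indexes = torch.cat([right_context_indexes, utterance_indexes])
key_indexes = torch.cat(
    [memory_indexes, right_context_indexes, left_context_indexes, utterance_indexes]
)
rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
rel_pos = rel_pos - rel_pos.min()
print(rel_pos.shape)                                      # (R + U, M + R + L + U) = (6, 13)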
@ -1514,8 +1712,10 @@ class EmformerEncoder(nn.Module):
output_attn_cache,
output_conv_cache,
) = layer.infer(
utterance=output,
right_context=right_context,
pos_emb=pos_emb,
rel_pos=rel_pos,
padding_mask=padding_mask,
attn_cache=attn_caches[layer_idx],
conv_cache=conv_caches[layer_idx],
@ -1597,6 +1797,10 @@ class Emformer(EncoderInterface):
raise NotImplementedError(
"right_context_length must be 0 or a multiple of subsampling_factor."  # noqa
)
if memory_size > 0 and memory_size * chunk_length < left_context_length:
raise NotImplementedError(
"memory_size * chunk_length must not be smaller than left_context_length." # noqa
)
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).
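The new check appears to guard the index construction used in EmformerEncoder.infer: the left-context frames are assigned absolute positions starting at tot_past_length - left_context_length with tot_past_length = memory_size * chunk_length, so the product must cover the left context or those positions would become negative. A small illustrative check with toy numbers, not taken from the commit:

memory_size, chunk_length, left_context_length = 4, 8, 32       # toy values
tot_past_length = memory_size * chunk_length                    # 32
first_left_context_position = tot_past_length - left_context_length   # 0, still valid
assert memory_size * chunk_length >= left_context_length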
@ -1822,3 +2026,119 @@ class Conv2dSubsampling(nn.Module):
x = self.out_norm(x)
x = self.out_balancer(x)
return x
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" # noqa
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py # noqa
Suppose:
i -> position of query,
j -> position of key (value),
we use the positive relative position embedding when the key (value) is to
the left of the query (i.e., i > j), and the negative embedding otherwise.
Args:
d_model: Embedding dimension.
dropout: Dropout rate.
max_len: Maximum input length.
"""
def __init__(
self, d_model: int, dropout: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.dropout = torch.nn.Dropout(p=dropout)
self.pe = None
self.pos_len = max_len
self.neg_len = max_len
self.gen_pe_positive()
self.gen_pe_negative()
def gen_pe_positive(self) -> None:
"""Generate the positive positional encodings."""
pe_positive = torch.zeros(self.pos_len, self.d_model)
position_positive = torch.arange(
0, self.pos_len, dtype=torch.float32
).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position_positive * div_term)
pe_positive[:, 1::2] = torch.cos(position_positive * div_term)
# Reverse the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" # noqa
self.pe_positive = torch.flip(pe_positive, [0])
def gen_pe_negative(self) -> None:
"""Generate the negative positional encodings."""
# Suppose `i` is the position of the query vector and `j` the
# position of the key vector. We use positive relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_negative = torch.zeros(self.neg_len, self.d_model)
position_negative = torch.arange(
0, self.neg_len, dtype=torch.float32
).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_negative[:, 0::2] = torch.sin(-1 * position_negative * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position_negative * div_term)
self.pe_negative = pe_negative
def get_pe(
self,
pos_len: int,
neg_len: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
"""Get positional encoding given positive length and negative length."""
if self.pe_positive.dtype != dtype or str(
self.pe_positive.device
) != str(device):
self.pe_positive = self.pe_positive.to(dtype=dtype, device=device)
if self.pe_negative.dtype != dtype or str(
self.pe_negative.device
) != str(device):
self.pe_negative = self.pe_negative.to(dtype=dtype, device=device)
pe = torch.cat(
[
self.pe_positive[self.pos_len - pos_len :],
self.pe_negative[1:neg_len],
],
dim=0,
)
return pe
def forward(
self,
x: torch.Tensor,
pos_len: int,
neg_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Scale input x and get positional encoding.
Args:
x (torch.Tensor): Input tensor (`*`).
Returns:
torch.Tensor:
Encoded tensor of shape (`*`).
torch.Tensor:
Position embedding of shape (pos_len + neg_len - 1, `*`).
"""
if pos_len > self.pos_len:
self.pos_len = pos_len
self.gen_pe_positive()
if neg_len > self.neg_len:
self.neg_len = neg_len
self.gen_pe_negative()
pos_emb = self.get_pe(pos_len, neg_len, x.device, x.dtype)
return self.dropout(x), self.dropout(pos_emb)
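A minimal usage sketch for the new RelPositionalEncoding module, using toy sizes: in the non-streaming forward pass the encoder calls it with pos_len == neg_len == x.size(0), while in streaming inference pos_len is extended by memory_size * chunk_length.

import torch

d_model, batch_size, T = 16, 2, 10                        # toy sizes
pos_enc = RelPositionalEncoding(d_model, dropout=0.1)

x = torch.randn(T, batch_size, d_model)                   # (T, B, D)
x, pos_emb = pos_enc(x, pos_len=T, neg_len=T)
print(pos_emb.shape)                                      # (pos_len + neg_len - 1, d_model) = (19, 16)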

View File

@ -114,8 +114,12 @@ def test_state_stack_unstack():
for _ in range(num_encoder_layers)
]
states = [attn_caches, conv_caches]
x = torch.randn(
batch_size, chunk_length + right_context_length + 3, num_features
)
x_lens = torch.full(
(batch_size,), chunk_length + right_context_length + 3
)
num_processed_frames = torch.full((batch_size,), 0)
y, y_lens, states = model.infer(
x, x_lens, num_processed_frames=num_processed_frames, states=states
@ -172,8 +176,10 @@ def test_torchscript_consistency_infer():
for _ in range(num_encoder_layers)
]
states = [attn_caches, conv_caches]
x = torch.randn(
batch_size, chunk_length + right_context_length + 3, num_features
)
x_lens = torch.full((batch_size,), chunk_length + right_context_length + 3)
num_processed_frames = torch.full((batch_size,), 0)
y, y_lens, out_states = model.infer(
x, x_lens, num_processed_frames=num_processed_frames, states=states
@ -187,8 +193,38 @@ def test_torchscript_consistency_infer():
assert torch.allclose(y, sc_y)
def test_emformer_forward_shape():
num_features = 80
chunk_length = 32
encoder_dim = 512
num_encoder_layers = 2
kernel_size = 31
left_context_length = 32
right_context_length = 8
memory_size = 32
batch_size = 2
model = Emformer(
num_features=num_features,
chunk_length=chunk_length,
subsampling_factor=4,
d_model=encoder_dim,
num_encoder_layers=num_encoder_layers,
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
memory_size=memory_size,
)
U = 2 * chunk_length
x = torch.randn(batch_size, U + right_context_length + 3, num_features)
x_lens = torch.full((batch_size,), U + right_context_length + 3)
output, output_lengths = model(x, x_lens)
assert output.shape == (batch_size, U >> 2, encoder_dim)
if __name__ == "__main__":
test_convolution_module_forward()
test_convolution_module_infer()
test_state_stack_unstack()
test_torchscript_consistency_infer()
test_emformer_forward_shape()