minor refactor of Emformer code

yaozengwei 2022-05-10 11:15:44 +08:00
parent aff7c4ee3c
commit e3a29b17f3
2 changed files with 138 additions and 154 deletions


@@ -40,24 +40,6 @@ def _get_activation_module(activation: str) -> nn.Module:
    raise ValueError(f"Unsupported activation {activation}")


-def _get_weight_init_gains(
-    weight_init_scale_strategy: Optional[str], num_layers: int
-) -> List[Optional[float]]:
-    if weight_init_scale_strategy is None:
-        return [None for _ in range(num_layers)]
-    elif weight_init_scale_strategy == "depthwise":
-        return [
-            1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)
-        ]
-    elif weight_init_scale_strategy == "constant":
-        return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
-    else:
-        raise ValueError(
-            f"Unsupported weight_init_scale_strategy value"
-            f"{weight_init_scale_strategy}"
-        )


def _gen_attention_mask_block(
    col_widths: List[int],
    col_mask: List[bool],
@@ -154,6 +136,8 @@ class EmformerAttention(nn.Module):
        Embedding dimension.
      nhead (int):
        Number of attention heads in each Emformer layer.
+      dropout (float):
+        Dropout probability applied to attn_output_weights. (Default: 0.0)
      tanh_on_mem (bool, optional):
        If ``True``, applies tanh to memory elements. (Default: ``False``)
      negative_inf (float, optional):
@@ -164,6 +148,7 @@ class EmformerAttention(nn.Module):
        self,
        embed_dim: int,
        nhead: int,
+        dropout: float = 0.0,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
@@ -173,13 +158,14 @@ class EmformerAttention(nn.Module):
            raise ValueError(
                f"embed_dim ({embed_dim}) is not a multiple of nhead ({nhead})."
            )

        self.embed_dim = embed_dim
        self.nhead = nhead
        self.tanh_on_mem = tanh_on_mem
        self.negative_inf = negative_inf
        self.head_dim = embed_dim // nhead
+        self.dropout = dropout
        self.scaling = self.head_dim ** -0.5

        self.emb_to_key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
@@ -262,6 +248,9 @@ class EmformerAttention(nn.Module):
            attention_weights_float, dim=-1
        ).type_as(attention_weights)
+        attention_probs = nn.functional.dropout(
+            attention_probs, p=self.dropout, training=self.training
+        )
        return attention_probs

    def _rel_shift(self, x: torch.Tensor) -> torch.Tensor:
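A brief aside on the dropout call added above (an illustrative sketch, not code from this patch): nn.functional.dropout uses inverted dropout, so with training=True each attention weight is zeroed with probability p and the survivors are rescaled by 1 / (1 - p), while with training=False it is an identity and the probabilities pass through unchanged once a module is put in eval mode, which the consistency tests below rely on.

    import torch
    import torch.nn.functional as F

    attention_probs = torch.softmax(torch.randn(2, 5, 5), dim=-1)

    # training=True: some weights are zeroed, the rest scaled by 1 / (1 - p)
    train_probs = F.dropout(attention_probs, p=0.1, training=True)

    # training=False: no-op, probabilities are returned unchanged
    eval_probs = F.dropout(attention_probs, p=0.1, training=False)
    assert torch.equal(eval_probs, attention_probs)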
@@ -311,12 +300,12 @@ class EmformerAttention(nn.Module):
        KV: length of attention key and value.

        1) Concat right_context, utterance, summary,
-           and compute query tensor with length Q = R + U + S.
+           and compute query with length Q = R + U + S.
        2) Concat memory, right_context, utterance,
-           and compute key, value tensors with length KV = M + R + U;
-           optionally with left_context_key and left_context_val (inference mode),
+           and compute key, value with length KV = M + R + U;
+           also with left_context_key and left_context_val for inference mode,
            then KV = M + R + L + U.
-        3) Compute entire attention scores with query, key, and value,
+        3) Compute entire attention scores with above query, key, and value,
            then apply attention_mask to get underlying chunk-wise attention scores.

        Args:
@@ -336,13 +325,13 @@ class EmformerAttention(nn.Module):
          pos_emb (torch.Tensor):
            Position encoding embedding, with shape (PE, D).
            For training mode, PE = 2 * U - 1;
-            For infer mode, PE = L+2*U-1.
+            For inference mode, PE = L + 2 * U - 1.
          left_context_key (torch.Tensor, optional):
            Cached attention key of left context from preceding computation,
-            with shape (L, B, D).
+            with shape (L, B, D). It is used for inference mode.
          left_context_val (torch.Tensor, optional):
            Cached attention value of left context from preceding computation,
-            with shape (L, B, D).
+            with shape (L, B, D). It is used for inference mode.

        Returns:
          A tuple containing 4 tensors:
@@ -355,23 +344,21 @@ class EmformerAttention(nn.Module):
        R = right_context.size(0)
        M = memory.size(0)

-        # Compute query with [right context, utterance, summary].
+        # compute query with [right context, utterance, summary].
        query = self.emb_to_query(
            torch.cat([right_context, utterance, summary])
        )
-        # Compute key and value with [mems, right context, utterance].
+        # compute key and value with [mems, right context, utterance].
        key, value = self.emb_to_key_value(
            torch.cat([memory, right_context, utterance])
        ).chunk(chunks=2, dim=2)

        if left_context_key is not None and left_context_val is not None:
-            # This is for inference mode. Now compute key and value with
+            # compute key and value with
            # [mems, right context, left context, utterance]
-            key = torch.cat(
-                [key[: M + R], left_context_key, key[M + R :]]  # noqa
-            )
+            key = torch.cat([key[: M + R], left_context_key, key[M + R :]])
            value = torch.cat(
-                [value[: M + R], left_context_val, value[M + R :]]  # noqa
+                [value[: M + R], left_context_val, value[M + R :]]
            )

        Q = query.size(0)
        KV = key.size(0)
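To make the length bookkeeping in this function concrete, a small worked example with hypothetical sizes (illustration only):

    # hypothetical sizes: memory, right context, utterance, summary, cached left context
    M, R, U, S, L = 3, 2, 12, 1, 8

    Q = R + U + S             # query length:     2 + 12 + 1     = 15
    KV_train = M + R + U      # key/value length: 3 + 2 + 12     = 17 (training mode)
    KV_infer = M + R + L + U  # key/value length: 3 + 2 + 8 + 12 = 25 (inference mode)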
@@ -381,12 +368,14 @@ class EmformerAttention(nn.Module):
                .view(KV, B * self.nhead, self.head_dim)
                .transpose(0, 1)
            for tensor in [key, value]
-        ]  # (B * nhead, KV, head_dim)
+        ]  # both of shape (B * nhead, KV, head_dim)
        reshaped_query = query.contiguous().view(
            Q, B, self.nhead, self.head_dim
        )

-        # compute attention matrix ac
+        # compute attention score
+        # first compute attention matrix a and matrix c
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3  # noqa
        query_with_bais_u = (
            (reshaped_query + self.pos_bias_u)
            .view(Q, B * self.nhead, self.head_dim)
@@ -396,7 +385,9 @@ class EmformerAttention(nn.Module):
            query_with_bais_u, reshaped_key.transpose(1, 2)
        )  # (B * nhead, Q, KV)

-        # compute attention matrix bd
+        # second, compute attention matrix b and matrix d
+        # relative positional encoding is applied on the part of attention
+        # between query: [utterance] -> key, value: [left_context, utterance]
        utterance_with_bais_v = (
            reshaped_query[R : R + U] + self.pos_bias_v
        ).permute(1, 2, 0, 3)
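For reference, the Transformer-XL decomposition these comments point to splits the attention score between query position i and key position j into four terms (a sketch in the paper's notation, not code from this commit):

    score(i, j) = q_i^T k_j           (a)  content-content
                + q_i^T W_R r_{i-j}   (b)  content-position
                + u^T k_j             (c)  global content bias
                + v^T W_R r_{i-j}     (d)  global position bias

Here matrix_ac collects (a) + (c) by forming (query + pos_bias_u)^T key, while matrix_bd collects (b) + (d) by forming (query + pos_bias_v)^T, for the utterance part of the query, against the relative position embeddings in pos_emb, followed by the rel-shift below.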
@@ -416,10 +407,10 @@ class EmformerAttention(nn.Module):
        matrix_bd_utterance = torch.matmul(
            utterance_with_bais_v, pos_emb.transpose(-2, -1)
        )  # (B, nhead, U, PE)
-        # rel-shift
-        matrix_bd_utterance = self._rel_shift(
-            matrix_bd_utterance
-        )  # (B, nhead, U, U or L + U)
+        # rel-shift operation
+        matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
+        # (B, nhead, U, U) for training mode;
+        # (B, nhead, U, L + U) for inference mode.
        matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
            B * self.nhead, U, -1
        )
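The rel-shift step re-indexes scores laid out per relative distance into scores laid out per key position. One common way to implement it is the Transformer-XL pad-and-reshape trick sketched below; this is a self-contained illustration that assumes index 0 of the last dimension corresponds to the largest positive distance, and the actual _rel_shift in this file may differ in detail:

    import torch

    def rel_shift(x: torch.Tensor, key_len: int) -> torch.Tensor:
        # x: (B, H, U, PE) with PE = key_len + U - 1, last dim ordered from the
        # largest relative distance down to the smallest.
        # Returns (B, H, U, key_len) with out[b, h, i, j] = x[b, h, i, j + (U - 1 - i)].
        B, H, U, PE = x.shape
        zero_pad = torch.zeros((B, H, U, 1), dtype=x.dtype, device=x.device)
        x_padded = torch.cat([zero_pad, x], dim=-1)  # (B, H, U, PE + 1)
        x_padded = x_padded.view(B, H, PE + 1, U)    # re-slice the rows
        return x_padded[:, :, 1:].reshape(B, H, U, PE)[..., :key_len]

    # e.g. U = 4 utterance queries, L = 2 cached left-context frames
    scores = torch.randn(1, 8, 4, 2 + 2 * 4 - 1)     # PE = L + 2 * U - 1 = 9
    shifted = rel_shift(scores, key_len=2 + 4)       # -> (1, 8, 4, 6)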
@@ -428,25 +419,25 @@ class EmformerAttention(nn.Module):
        attention_weights = (matrix_ac + matrix_bd) * self.scaling

-        # Compute padding mask
+        # compute padding mask
        if B == 1:
            padding_mask = None
        else:
            padding_mask = make_pad_mask(KV - U + lengths)

-        # Compute attention probabilities.
+        # compute attention probabilities
        attention_probs = self._gen_attention_probs(
            attention_weights, attention_mask, padding_mask
        )

-        # Compute attention.
+        # compute attention outputs
        attention = torch.bmm(attention_probs, reshaped_value)
        assert attention.shape == (B * self.nhead, Q, self.head_dim)
        attention = (
            attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim)
        )

-        # Apply output projection.
+        # apply output projection
        outputs = self.out_proj(attention)

        output_right_context_utterance = outputs[: R + U]
@@ -487,7 +478,7 @@ class EmformerAttention(nn.Module):
          right_context (torch.Tensor):
            Right context frames, with shape (R, B, D).
          summary (torch.Tensor):
-            Summary elements, with shape (S, B, D).
+            Summary elements with shape (S, B, D) or an empty tensor.
          memory (torch.Tensor):
            Memory elements, with shape (M, B, D).
          attention_mask (torch.Tensor):
@@ -495,7 +486,7 @@ class EmformerAttention(nn.Module):
            with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
          pos_emb (torch.Tensor):
            Position encoding embedding, with shape (PE, D).
-            For training mode, P = 2*U-1.
+            where PE = 2 * U - 1.

        Returns:
          A tuple containing 2 tensors:
@@ -549,7 +540,7 @@ class EmformerAttention(nn.Module):
          right_context (torch.Tensor):
            Right context frames, with shape (R, B, D).
          summary (torch.Tensor):
-            Summary element, with shape (1, B, D), or empty.
+            Summary element with shape (1, B, D), or an empty tensor.
          memory (torch.Tensor):
            Memory elements, with shape (M, B, D).
          left_context_key (torch.Tensor):
@@ -571,19 +562,20 @@ class EmformerAttention(nn.Module):
          - attention value of left context and utterance, which would be
            cached for next computation, with shape (L + U, B, D).
        """
+        U = utterance.size(0)
+        R = right_context.size(0)
+        L = left_context_key.size(0)
+        S = summary.size(0)
+        M = memory.size(0)

        # query: [right context, utterance, summary]
-        Q = right_context.size(0) + utterance.size(0) + summary.size(0)
+        Q = R + U + S
        # key, value: [memory, right context, left context, utterance]
-        KV = (
-            memory.size(0)
-            + right_context.size(0)  # noqa
-            + left_context_key.size(0)  # noqa
-            + utterance.size(0)  # noqa
-        )
+        KV = M + R + L + U
        attention_mask = torch.zeros(Q, KV).to(
            dtype=torch.bool, device=utterance.device
        )
-        # Disallow attention bettween the summary vector with the memory bank
+        # disallow attention between the summary vector and the memory bank
        attention_mask[-1, : memory.size(0)] = True
        (
            output_right_context_utterance,
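A small illustration of the mask built above, using hypothetical sizes (True marks key positions a query row may not attend to; illustration only):

    import torch

    M, R, L, U, S = 2, 1, 3, 4, 1     # memory, right context, left context, utterance, summary
    Q, KV = R + U + S, M + R + L + U  # 6 query rows, 10 key/value columns

    attention_mask = torch.zeros(Q, KV, dtype=torch.bool)
    attention_mask[-1, :M] = True     # the summary query (last row) ignores the memory bank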
@@ -601,12 +593,11 @@ class EmformerAttention(nn.Module):
            left_context_key=left_context_key,
            left_context_val=left_context_val,
        )
-        right_context_end_idx = memory.size(0) + right_context.size(0)
        return (
            output_right_context_utterance,
            output_memory,
-            key[right_context_end_idx:],
-            value[right_context_end_idx:],
+            key[M + R :],
+            value[M + R :],
        )
@@ -656,6 +647,7 @@ class EmformerLayer(nn.Module):
        self.attention = EmformerAttention(
            embed_dim=d_model,
            nhead=nhead,
+            dropout=dropout,
            tanh_on_mem=tanh_on_mem,
            negative_inf=negative_inf,
        )
@@ -756,9 +748,9 @@ class EmformerLayer(nn.Module):
        layer_norm_input = self.layer_norm_input(
            torch.cat([right_context, utterance])
        )
-        right_context_end_idx = right_context.size(0)
-        layer_norm_utterance = layer_norm_input[right_context_end_idx:]
-        layer_norm_right_context = layer_norm_input[:right_context_end_idx]
+        R = right_context.size(0)
+        layer_norm_utterance = layer_norm_input[R:]
+        layer_norm_right_context = layer_norm_input[:R]
        return layer_norm_utterance, layer_norm_right_context

    def _apply_post_attention_ffn_layer_norm(
@@ -768,18 +760,18 @@ class EmformerLayer(nn.Module):
        right_context: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply feed forward and layer normalization after attention."""
-        # Apply residual connection between input and attention output.
+        # apply residual connection between input and attention output.
        result = self.dropout(output_right_context_utterance) + torch.cat(
            [right_context, utterance]
        )
-        # Apply feedforward module and residual connection.
+        # apply feedforward module and residual connection.
        result = self.pos_ff(result) + result
-        # Apply layer normalization for output.
+        # apply layer normalization for output.
        result = self.layer_norm_output(result)

-        right_context_end_idx = right_context.size(0)
-        output_utterance = result[right_context_end_idx:]
-        output_right_context = result[:right_context_end_idx]
+        R = right_context.size(0)
+        output_utterance = result[R:]
+        output_right_context = result[:R]
        return output_utterance, output_right_context

    def _apply_attention_forward(
@@ -796,7 +788,6 @@ class EmformerLayer(nn.Module):
            raise ValueError(
                "attention_mask must be not None in non-infer mode. "
            )
-
        if self.use_memory:
            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
                2, 0, 1
@@ -851,8 +842,10 @@ class EmformerLayer(nn.Module):
            summary = torch.empty(0).to(
                dtype=utterance.dtype, device=utterance.device
            )
-        # pos_emb is of shape [PE, D], PE = L + 2 * U - 1,
-        # the relative distance j - i of key(j) and query(i) is in range of [-(L + U - 1), (U - 1)]  # noqa
+        # pos_emb is of shape [PE, D], where PE = L + 2 * U - 1,
+        # for query of [utterance] (i), key-value [left_context, utterance] (j),
+        # the max relative distance i - j is L + U - 1
+        # the min relative distance i - j is -(U - 1)
        L = left_context_key.size(0)  # L <= left_context_length
        U = utterance.size(0)
        PE = L + 2 * U - 1
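A quick check of that arithmetic (illustrative values only): with i ranging over the U utterance queries and j over the L + U key frames (left context followed by utterance), i - j takes every integer value from -(U - 1) up to L + U - 1, i.e. L + 2 * U - 1 distinct distances, one pos_emb row per distance:

    L, U = 3, 4
    distances = {
        i - j
        for i in range(U)        # query positions within the utterance
        for j in range(-L, U)    # key positions: left context (negative) then utterance
    }
    assert min(distances) == -(U - 1)
    assert max(distances) == L + U - 1
    assert len(distances) == L + 2 * U - 1  # = PE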
@@ -916,8 +909,8 @@ class EmformerLayer(nn.Module):
            Attention mask for underlying attention module,
            with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For training mode, P = 2*U-1.
+            Position encoding embedding, with shape (PE, D),
+            where PE = 2 * U - 1.

        Returns:
          A tuple containing 3 tensors:
@@ -987,8 +980,8 @@ class EmformerLayer(nn.Module):
            List of tensors representing layer internal state generated in
            preceding computation. (default=None)
          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For infer mode, PE = L+2*U-1.
+            Position encoding embedding, with shape (PE, D),
+            where PE = L + 2 * U - 1.

        Returns:
          (Tensor, Tensor, List[torch.Tensor], Tensor):
@@ -1073,7 +1066,6 @@ class EmformerEncoder(nn.Module):
        left_context_length: int = 0,
        right_context_length: int = 0,
        max_memory_size: int = 0,
-        weight_init_scale_strategy: str = "depthwise",
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
@@ -1104,6 +1096,8 @@ class EmformerEncoder(nn.Module):
            ]
        )

+        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+
        self.left_context_length = left_context_length
        self.right_context_length = right_context_length
        self.chunk_length = chunk_length
@@ -1246,10 +1240,7 @@ class EmformerEncoder(nn.Module):
        return attention_mask

    def forward(
-        self,
-        x: torch.Tensor,
-        lengths: torch.Tensor,
-        pos_emb: torch.Tensor,
+        self, x: torch.Tensor, lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for training and non-streaming inference.
@@ -1265,9 +1256,6 @@ class EmformerEncoder(nn.Module):
            With shape (B,) and i-th element representing number of valid
            utterance frames for i-th batch element in x, which contains the
            right_context at the end.
-          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For training mode, P = 2*U-1.

        Returns:
          A tuple of 2 tensors:
@@ -1275,8 +1263,11 @@ class EmformerEncoder(nn.Module):
          - output_lengths, with shape (B,), without containing the
            right_context at the end.
        """
+        U = x.size(0) - self.right_context_length
+        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
+
        right_context = self._gen_right_context(x)
-        utterance = x[: x.size(0) - self.right_context_length]
+        utterance = x[:U]
        output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
        attention_mask = self._gen_attention_mask(utterance)
        memory = (
@@ -1286,6 +1277,7 @@ class EmformerEncoder(nn.Module):
            if self.use_memory
            else torch.empty(0).to(dtype=x.dtype, device=x.device)
        )
+
        output = utterance
        for layer in self.emformer_layers:
            output, right_context, memory = layer(
@@ -1304,7 +1296,6 @@ class EmformerEncoder(nn.Module):
        self,
        x: torch.Tensor,
        lengths: torch.Tensor,
-        pos_emb: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        """Forward pass for streaming inference.
@@ -1325,9 +1316,6 @@ class EmformerEncoder(nn.Module):
            Cached states from preceding chunk's computation, where each
            element (List[torch.Tensor]) corresponding to each emformer layer.
            (default: None)
-          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For infer mode, PE = L+2*U-1.

        Returns:
          (Tensor, Tensor, List[List[torch.Tensor]]):
@@ -1341,9 +1329,12 @@ class EmformerEncoder(nn.Module):
                f"expected size of {self.chunk_length + self.right_context_length} "
                f"for dimension 1 of x, but got {x.size(1)}."
            )
-        right_context_start_idx = x.size(0) - self.right_context_length
-        right_context = x[right_context_start_idx:]
-        utterance = x[:right_context_start_idx]
+        pos_len = self.chunk_length + self.left_context_length
+        neg_len = self.chunk_length
+        x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
+
+        right_context = x[self.chunk_length :]
+        utterance = x[: self.chunk_length]
        output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
        memory = (
            self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
@@ -1383,7 +1374,6 @@ class Emformer(EncoderInterface):
        left_context_length: int = 0,
        right_context_length: int = 0,
        max_memory_size: int = 0,
-        weight_init_scale_strategy: str = "depthwise",
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
@@ -1416,8 +1406,6 @@ class Emformer(EncoderInterface):
        else:
            self.encoder_embed = Conv2dSubsampling(num_features, d_model)

-        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
-
        self.encoder = EmformerEncoder(
            chunk_length // 4,
            d_model,
@@ -1429,7 +1417,6 @@ class Emformer(EncoderInterface):
            left_context_length=left_context_length // 4,
            right_context_length=right_context_length // 4,
            max_memory_size=max_memory_size,
-            weight_init_scale_strategy=weight_init_scale_strategy,
            tanh_on_mem=tanh_on_mem,
            negative_inf=negative_inf,
        )
@@ -1465,10 +1452,6 @@ class Emformer(EncoderInterface):
            right_context at the end.
        """
        x = self.encoder_embed(x)
-        # TODO: The length computation in the encoder class should be moved here.  # noqa
-        U = x.size(1) - self.right_context_length // 4
-        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
-
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # Caution: We assume the subsampling factor is 4!
@@ -1477,7 +1460,7 @@ class Emformer(EncoderInterface):
        x_lens = ((x_lens - 1) // 2 - 1) // 2
        assert x.size(0) == x_lens.max().item()

-        output, output_lengths = self.encoder(x, x_lens, pos_emb)  # (T, N, C)
+        output, output_lengths = self.encoder(x, x_lens)  # (T, N, C)

        logits = self.encoder_output_layer(output)
        logits = logits.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
@@ -1518,12 +1501,6 @@ class Emformer(EncoderInterface):
            - updated states from current chunk's computation.
        """
        x = self.encoder_embed(x)
-        # TODO: The length computation in the encoder class should be moved here.  # noqa
-        pos_len = self.chunk_length // 4 + self.left_context_length // 4
-        neg_len = self.chunk_length // 4
-        x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
-
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # Caution: We assume the subsampling factor is 4!
@@ -1533,7 +1510,7 @@ class Emformer(EncoderInterface):
        assert x.size(0) == x_lens.max().item()

        output, output_lengths, output_states = self.encoder.infer(
-            x, x_lens, pos_emb, states
+            x, x_lens, states
        )  # (T, N, C)

        logits = self.encoder_output_layer(output)


@@ -199,7 +199,6 @@ def test_emformer_encoder_forward():
    chunk_length = 4
    right_context_length = 2
    left_context_length = 2
-    left_context_length = 2
    num_chunks = 3
    U = num_chunks * chunk_length
@@ -223,10 +222,8 @@ def test_emformer_encoder_forward():
    x = torch.randn(U + right_context_length, B, D)
    lengths = torch.randint(1, U + right_context_length + 1, (B,))
    lengths[0] = U + right_context_length
-    PE = 2 * U - 1
-    pos_emb = torch.randn(PE, D)

-    output, output_lengths = encoder(x, lengths, pos_emb)
+    output, output_lengths = encoder(x, lengths)
    assert output.shape == (U, B, D)
    assert torch.equal(
        output_lengths, torch.clamp(lengths - right_context_length, min=0)
@@ -266,11 +263,7 @@ def test_emformer_encoder_infer():
            1, chunk_length + right_context_length + 1, (B,)
        )
        lengths[0] = chunk_length + right_context_length
-        PE = left_context_length + 2 * chunk_length - 1
-        pos_emb = torch.randn(PE, D)
-        output, output_lengths, states = encoder.infer(
-            x, lengths, pos_emb, states
-        )
+        output, output_lengths, states = encoder.infer(x, lengths, states)
        assert output.shape == (chunk_length, B, D)
        assert torch.equal(
            output_lengths,
@@ -383,6 +376,7 @@ def test_emformer_infer():

def test_emformer_attention_forward_infer_consistency():
+    # TODO: delete
    from emformer import EmformerEncoder

    chunk_length = 4
@@ -474,7 +468,7 @@ def test_emformer_layer_forward_infer_consistency():
    chunk_length = 4
    num_chunks = 3
    U = chunk_length * num_chunks
-    L, R = 1, 2
+    left_context_length, right_context_length = 1, 2
    D = 256
    num_encoder_layers = 1
    memory_sizes = [0, 3]
@@ -485,18 +479,22 @@ def test_emformer_layer_forward_infer_consistency():
            d_model=D,
            dim_feedforward=1024,
            num_encoder_layers=num_encoder_layers,
-            left_context_length=L,
-            right_context_length=R,
+            left_context_length=left_context_length,
+            right_context_length=right_context_length,
            max_memory_size=M,
            dropout=0.1,
        )
        encoder.eval()
        encoder_layer = encoder.emformer_layers[0]
+        encoder_pos = encoder.encoder_pos

-        x = torch.randn(U + R, 1, D)
+        x = torch.randn(U + right_context_length, 1, D)
+
+        # training mode with full utterance
+        x_forward, pos_emb = encoder_pos(x, U, U)
        lengths = torch.tensor([U])
-        right_context = encoder._gen_right_context(x)
-        utterance = x[: x.size(0) - R]
+        right_context = encoder._gen_right_context(x_forward)
+        utterance = x_forward[:U]
        attention_mask = encoder._gen_attention_mask(utterance)
        memory = (
            encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
@@ -515,15 +513,20 @@ def test_emformer_layer_forward_infer_consistency():
            right_context,
            memory,
            attention_mask,
+            pos_emb,
        )

        state = None
        for chunk_idx in range(num_chunks):
            start_idx = chunk_idx * chunk_length
            end_idx = start_idx + chunk_length
-            chunk = x[start_idx:end_idx]
-            chunk_right_context = x[end_idx : end_idx + R]  # noqa
-            chunk_length = torch.tensor([chunk_length])
+            cur_x, pos_emb = encoder_pos(
+                x[start_idx : end_idx + right_context_length],
+                pos_len=chunk_length + left_context_length,
+                neg_len=chunk_length,
+            )
+            chunk = cur_x[:chunk_length]
+            chunk_right_context = cur_x[chunk_length:]
            chunk_memory = (
                encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(2, 0, 1)
                if encoder.use_memory
@@ -536,9 +539,10 @@ def test_emformer_layer_forward_infer_consistency():
                state,
            ) = encoder_layer.infer(
                chunk,
-                chunk_length,
+                torch.tensor([chunk_length]),
                chunk_right_context,
                chunk_memory,
+                pos_emb,
                state,
            )
            forward_output_chunk = forward_output_utterance[start_idx:end_idx]
@@ -551,7 +555,7 @@ def test_emformer_layer_forward_infer_consistency():

def test_emformer_encoder_forward_infer_consistency():
-    from emformer import EmformerEncoder, RelPositionalEncoding
+    from emformer import EmformerEncoder

    chunk_length = 4
    num_chunks = 3
@@ -573,28 +577,22 @@ def test_emformer_encoder_forward_infer_consistency():
            dropout=0.1,
        )
        encoder.eval()
-        encoder_pos = RelPositionalEncoding(D, dropout_rate=0)

        x = torch.randn(U + right_context_length, 1, D)
        lengths = torch.tensor([U + right_context_length])
-        _, pos_emb = encoder_pos(x, U, U)
-        forward_output, forward_output_lengths = encoder(x, lengths, pos_emb)
+        # training mode with full utterance
+        forward_output, forward_output_lengths = encoder(x, lengths)

+        # streaming inference mode with individual chunks
        states = None
-        _, pos_emb = encoder_pos(
-            x, chunk_length + left_context_length, chunk_length
-        )
        for chunk_idx in range(num_chunks):
            start_idx = chunk_idx * chunk_length
            end_idx = start_idx + chunk_length
            chunk = x[start_idx : end_idx + right_context_length]  # noqa
            chunk_length = torch.tensor([chunk_length])
            infer_output_chunk, infer_output_lengths, states = encoder.infer(
-                chunk,
-                chunk_length,
-                pos_emb,
-                states,
+                chunk, chunk_length, states
            )
            forward_output_chunk = forward_output[start_idx:end_idx]
            assert torch.allclose(
@@ -615,7 +613,7 @@ def test_emformer_infer_batch_single_consistency():
    chunk_length = 8
    num_chunks = 3
    U = num_chunks * chunk_length
-    L, R = 128, 4
+    left_context_length, right_context_length = 128, 4
    B, D = 2, 256
    num_encoder_layers = 2
    for use_memory in [True, False]:
@@ -630,8 +628,8 @@ def test_emformer_infer_batch_single_consistency():
            subsampling_factor=4,
            d_model=D,
            num_encoder_layers=num_encoder_layers,
-            left_context_length=L,
-            right_context_length=R,
+            left_context_length=left_context_length,
+            right_context_length=right_context_length,
            max_memory_size=M,
            vgg_frontend=False,
        )
@@ -689,20 +687,25 @@ def test_emformer_infer_batch_single_consistency():
            ],
        )

-        x = torch.randn(B, U + R + 3, num_features)
+        x = torch.randn(B, U + right_context_length + 3, num_features)
+
+        # batch-wise inference
        batch_logits = []
        batch_states = []
        states = None
        for chunk_idx in range(num_chunks):
            start_idx = chunk_idx * chunk_length
            end_idx = start_idx + chunk_length
-            chunk = x[:, start_idx : end_idx + R + 3]  # noqa
-            lengths = torch.tensor([chunk_length + R + 3]).expand(B)
+            chunk = x[:, start_idx : end_idx + right_context_length + 3]  # noqa
+            lengths = torch.tensor(
+                [chunk_length + right_context_length + 3]
+            ).expand(B)
            logits, output_lengths, states = model.infer(chunk, lengths, states)
            batch_logits.append(logits)
            batch_states.append(save_states(states))
        batch_logits = torch.cat(batch_logits, dim=1)

+        # single-wise inference
        single_logits = []
        for sample_idx in range(B):
            sample = x[sample_idx : sample_idx + 1]  # noqa
@@ -711,17 +714,21 @@ def test_emformer_infer_batch_single_consistency():
            for chunk_idx in range(num_chunks):
                start_idx = chunk_idx * chunk_length
                end_idx = start_idx + chunk_length
-                chunk = sample[:, start_idx : end_idx + R + 3]  # noqa
-                lengths = torch.tensor([chunk_length + R + 3])
+                chunk = sample[
+                    :, start_idx : end_idx + right_context_length + 3
+                ]
+                lengths = torch.tensor(
+                    [chunk_length + right_context_length + 3]
+                )
                logits, output_lengths, states = model.infer(
                    chunk, lengths, states
                )
                chunk_logits.append(logits)

                assert_states_equal(batch_states[chunk_idx], states, sample_idx)

            chunk_logits = torch.cat(chunk_logits, dim=1)
            single_logits.append(chunk_logits)

        single_logits = torch.cat(single_logits, dim=0)

        assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0)
@@ -734,7 +741,7 @@ def test_emformer_infer_states_stack():
    output_dim = 1000
    chunk_length = 8
    U = chunk_length
-    L, R = 128, 4
+    left_context_length, right_context_length = 128, 4
    B, D = 2, 256
    num_encoder_layers = 2
    for use_memory in [True, False]:
@@ -749,14 +756,14 @@ def test_emformer_infer_states_stack():
            subsampling_factor=4,
            d_model=D,
            num_encoder_layers=num_encoder_layers,
-            left_context_length=L,
-            right_context_length=R,
+            left_context_length=left_context_length,
+            right_context_length=right_context_length,
            max_memory_size=M,
            vgg_frontend=False,
        )

-        x = torch.randn(B, U + R + 3, num_features)
-        x_lens = torch.full((B,), U + R + 3)
+        x = torch.randn(B, U + right_context_length + 3, num_features)
+        x_lens = torch.full((B,), U + right_context_length + 3)
        logits, output_lengths, states = model.infer(
            x,
            x_lens,
@@ -790,8 +797,8 @@ if __name__ == "__main__":
    test_emformer_forward()
    test_emformer_infer()
    # test_emformer_attention_forward_infer_consistency()
-    # test_emformer_layer_forward_infer_consistency()
+    test_emformer_layer_forward_infer_consistency()
    test_emformer_encoder_forward_infer_consistency()
-    # test_emformer_infer_batch_single_consistency()
-    # test_emformer_infer_states_stack()
+    test_emformer_infer_batch_single_consistency()
+    test_emformer_infer_states_stack()
    test_rel_positional_encoding()