Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-05 15:14:18 +00:00)

Commit e3a29b17f3 (parent aff7c4ee3c): minor refactor of emformer codes
File: emformer.py

@@ -40,24 +40,6 @@ def _get_activation_module(activation: str) -> nn.Module:
     raise ValueError(f"Unsupported activation {activation}")
 
 
-def _get_weight_init_gains(
-    weight_init_scale_strategy: Optional[str], num_layers: int
-) -> List[Optional[float]]:
-    if weight_init_scale_strategy is None:
-        return [None for _ in range(num_layers)]
-    elif weight_init_scale_strategy == "depthwise":
-        return [
-            1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)
-        ]
-    elif weight_init_scale_strategy == "constant":
-        return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
-    else:
-        raise ValueError(
-            f"Unsupported weight_init_scale_strategy value"
-            f"{weight_init_scale_strategy}"
-        )
-
-
 def _gen_attention_mask_block(
     col_widths: List[int],
     col_mask: List[bool],
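The removed _get_weight_init_gains helper produced one gain per layer: 1 / sqrt(layer_idx + 1) under the "depthwise" strategy and 1 / sqrt(2) under "constant". A tiny standalone illustration of the depthwise values (plain Python, independent of the icefall API):

import math

def depthwise_gains(num_layers: int):
    # gain for layer i is 1 / sqrt(i + 1), as in the deleted helper
    return [1.0 / math.sqrt(i + 1) for i in range(num_layers)]

print(depthwise_gains(4))  # [1.0, 0.707..., 0.577..., 0.5]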
@@ -154,6 +136,8 @@ class EmformerAttention(nn.Module):
        Embedding dimension.
      nhead (int):
        Number of attention heads in each Emformer layer.
+      dropout (float):
+        A Dropout layer on attn_output_weights. (Default: 0.0)
      tanh_on_mem (bool, optional):
        If ``True``, applies tanh to memory elements. (Default: ``False``)
      negative_inf (float, optional):
@@ -164,6 +148,7 @@ class EmformerAttention(nn.Module):
         self,
         embed_dim: int,
         nhead: int,
+        dropout: float = 0.0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
     ):
@@ -173,13 +158,14 @@ class EmformerAttention(nn.Module):
             raise ValueError(
                 f"embed_dim ({embed_dim}) is not a multiple of nhead ({nhead})."
             )
 
         self.embed_dim = embed_dim
         self.nhead = nhead
         self.tanh_on_mem = tanh_on_mem
         self.negative_inf = negative_inf
         self.head_dim = embed_dim // nhead
+        self.dropout = dropout
         self.scaling = self.head_dim ** -0.5
 
         self.emb_to_key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
@@ -262,6 +248,9 @@ class EmformerAttention(nn.Module):
             attention_weights_float, dim=-1
         ).type_as(attention_weights)
 
+        attention_probs = nn.functional.dropout(
+            attention_probs, p=self.dropout, training=self.training
+        )
         return attention_probs
 
     def _rel_shift(self, x: torch.Tensor) -> torch.Tensor:
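The dropout added above acts on the softmax-normalized attention weights and, because it is called with training=self.training, becomes a no-op in eval mode. A self-contained illustration with arbitrary sizes:

import torch

attention_probs = torch.softmax(torch.randn(4, 6, 10), dim=-1)

train_out = torch.nn.functional.dropout(attention_probs, p=0.5, training=True)
eval_out = torch.nn.functional.dropout(attention_probs, p=0.5, training=False)

assert torch.equal(eval_out, attention_probs)  # unchanged when not training
print((train_out == 0).float().mean())         # roughly half the weights are zeroed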
@@ -311,12 +300,12 @@ class EmformerAttention(nn.Module):
        KV: length of attention key and value.
 
        1) Concat right_context, utterance, summary,
-        and compute query tensor with length Q = R + U + S.
+        and compute query with length Q = R + U + S.
        2) Concat memory, right_context, utterance,
-        and compute key, value tensors with length KV = M + R + U;
-        optionally with left_context_key and left_context_val (inference mode),
+        and compute key, value with length KV = M + R + U;
+        also with left_context_key and left_context_val for inference mode,
        then KV = M + R + L + U.
-        3) Compute entire attention scores with query, key, and value,
+        3) Compute entire attention scores with above query, key, and value,
        then apply attention_mask to get underlying chunk-wise attention scores.
 
        Args:
@@ -336,13 +325,13 @@ class EmformerAttention(nn.Module):
          pos_emb (torch.Tensor):
            Position encoding embedding, with shape (PE, D).
            For training mode, PE = 2 * U - 1;
-            For infer mode, PE = L+2*U-1.
+            For inference mode, PE = L + 2 * U - 1.
          left_context_key (torch.Tensor, optional):
            Cached attention key of left context from preceding computation,
-            with shape (L, B, D).
+            with shape (L, B, D). It is used for inference mode.
          left_context_val (torch.Tensor, optional):
            Cached attention value of left context from preceding computation,
-            with shape (L, B, D).
+            with shape (L, B, D). It is used for inference mode.
 
        Returns:
          A tuple containing 4 tensors:
@@ -355,23 +344,21 @@ class EmformerAttention(nn.Module):
         R = right_context.size(0)
         M = memory.size(0)
 
-        # Compute query with [right context, utterance, summary].
+        # compute query with [right context, utterance, summary].
         query = self.emb_to_query(
             torch.cat([right_context, utterance, summary])
         )
-        # Compute key and value with [mems, right context, utterance].
+        # compute key and value with [mems, right context, utterance].
         key, value = self.emb_to_key_value(
             torch.cat([memory, right_context, utterance])
         ).chunk(chunks=2, dim=2)
 
         if left_context_key is not None and left_context_val is not None:
-            # This is for inference mode. Now compute key and value with
+            # compute key and value with
             # [mems, right context, left context, utterance]
-            key = torch.cat(
-                [key[: M + R], left_context_key, key[M + R :]]  # noqa
-            )
+            key = torch.cat([key[: M + R], left_context_key, key[M + R :]])
             value = torch.cat(
-                [value[: M + R], left_context_val, value[M + R :]]  # noqa
+                [value[: M + R], left_context_val, value[M + R :]]
             )
         Q = query.size(0)
         KV = key.size(0)
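The concatenation above is what makes streaming inference incremental: cached left-context keys/values are spliced in between the [memory, right_context] prefix and the current utterance, giving KV = M + R + L + U as stated in the docstring. A shape-only sketch with made-up sizes:

import torch

M, R, L, U, B, D = 3, 2, 4, 8, 1, 16       # illustrative sizes, not from a real run
key = torch.randn(M + R + U, B, D)         # key over [memory, right_context, utterance]
left_context_key = torch.randn(L, B, D)    # cached from the preceding chunk

key = torch.cat([key[: M + R], left_context_key, key[M + R :]])
assert key.size(0) == M + R + L + U        # KV = M + R + L + U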
@@ -381,12 +368,14 @@ class EmformerAttention(nn.Module):
             .view(KV, B * self.nhead, self.head_dim)
             .transpose(0, 1)
             for tensor in [key, value]
-        ]  # (B * nhead, KV, head_dim)
+        ]  # both of shape (B * nhead, KV, head_dim)
         reshaped_query = query.contiguous().view(
             Q, B, self.nhead, self.head_dim
         )
 
-        # compute attention matrix ac
+        # compute attention score
+        # first compute attention matrix a and matrix c
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3  # noqa
         query_with_bais_u = (
             (reshaped_query + self.pos_bias_u)
             .view(Q, B * self.nhead, self.head_dim)
@@ -396,7 +385,9 @@ class EmformerAttention(nn.Module):
             query_with_bais_u, reshaped_key.transpose(1, 2)
         )  # (B * nhead, Q, KV)
 
-        # compute attention matrix bd
+        # second, compute attention matrix b and matrix d
+        # relative positional encoding is applied on the part of attention
+        # between query: [utterance] -> key, value: [left_context, utterance]
         utterance_with_bais_v = (
             reshaped_query[R : R + U] + self.pos_bias_v
         ).permute(1, 2, 0, 3)
@@ -416,10 +407,10 @@ class EmformerAttention(nn.Module):
         matrix_bd_utterance = torch.matmul(
             utterance_with_bais_v, pos_emb.transpose(-2, -1)
         )  # (B, nhead, U, PE)
-        # rel-shift
-        matrix_bd_utterance = self._rel_shift(
-            matrix_bd_utterance
-        )  # (B, nhead, U, U or L + U)
+        # rel-shift operation
+        matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
+        # (B, nhead, U, U) for training mode;
+        # (B, nhead, U, L + U) for inference mode.
         matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
             B * self.nhead, U, -1
         )
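For readers unfamiliar with the rel-shift trick: it turns scores computed against PE relative-position embeddings into a per-pair matrix via a pad-and-reshape. The sketch below is the standard Transformer-XL/Conformer variant for the training case (PE = 2 * U - 1); the _rel_shift here additionally has to cover the inference case (PE = L + 2 * U - 1), so treat this only as an illustration of the idea:

import torch

def rel_shift(x: torch.Tensor) -> torch.Tensor:
    # x: (B, nhead, U, 2 * U - 1), scores against embeddings ordered from the
    # most positive to the most negative relative offset.
    B, H, U, PE = x.shape
    zero_pad = torch.zeros((B, H, U, 1), dtype=x.dtype, device=x.device)
    x_padded = torch.cat([zero_pad, x], dim=-1)   # (B, H, U, PE + 1)
    x_padded = x_padded.view(B, H, PE + 1, U)
    x = x_padded[:, :, 1:].view(B, H, U, PE)
    return x[..., :U]  # entry (i, j) now holds the score for offset i - j

scores = torch.randn(1, 4, 6, 11)  # U = 6, PE = 2 * 6 - 1
print(rel_shift(scores).shape)     # torch.Size([1, 4, 6, 6])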
@@ -428,25 +419,25 @@ class EmformerAttention(nn.Module):
 
         attention_weights = (matrix_ac + matrix_bd) * self.scaling
 
-        # Compute padding mask
+        # compute padding mask
         if B == 1:
             padding_mask = None
         else:
             padding_mask = make_pad_mask(KV - U + lengths)
 
-        # Compute attention probabilities.
+        # compute attention probabilities
         attention_probs = self._gen_attention_probs(
             attention_weights, attention_mask, padding_mask
         )
 
-        # Compute attention.
+        # compute attention outputs
         attention = torch.bmm(attention_probs, reshaped_value)
         assert attention.shape == (B * self.nhead, Q, self.head_dim)
         attention = (
             attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim)
         )
 
-        # Apply output projection.
+        # apply output projection
         outputs = self.out_proj(attention)
 
         output_right_context_utterance = outputs[: R + U]
@@ -487,7 +478,7 @@ class EmformerAttention(nn.Module):
          right_context (torch.Tensor):
            Right context frames, with shape (R, B, D).
          summary (torch.Tensor):
-            Summary elements, with shape (S, B, D).
+            Summary elements with shape (S, B, D) or an empty tensor.
          memory (torch.Tensor):
            Memory elements, with shape (M, B, D).
          attention_mask (torch.Tensor):
@@ -495,7 +486,7 @@ class EmformerAttention(nn.Module):
            with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
          pos_emb (torch.Tensor):
            Position encoding embedding, with shape (PE, D).
-            For training mode, P = 2*U-1.
+            where PE = 2 * U - 1.
 
        Returns:
          A tuple containing 2 tensors:
@@ -549,7 +540,7 @@ class EmformerAttention(nn.Module):
          right_context (torch.Tensor):
            Right context frames, with shape (R, B, D).
          summary (torch.Tensor):
-            Summary element, with shape (1, B, D), or empty.
+            Summary element with shape (1, B, D), or an empty tensor.
          memory (torch.Tensor):
            Memory elements, with shape (M, B, D).
          left_context_key (torch.Tensor):
@@ -571,19 +562,20 @@ class EmformerAttention(nn.Module):
          - attention value of left context and utterance, which would be
            cached for next computation, with shape (L + U, B, D).
        """
+        U = utterance.size(0)
+        R = right_context.size(0)
+        L = left_context_key.size(0)
+        S = summary.size(0)
+        M = memory.size(0)
+
         # query: [right context, utterance, summary]
-        Q = right_context.size(0) + utterance.size(0) + summary.size(0)
+        Q = R + U + S
         # key, value: [memory, right context, left context, utterance]
-        KV = (
-            memory.size(0)
-            + right_context.size(0)  # noqa
-            + left_context_key.size(0)  # noqa
-            + utterance.size(0)  # noqa
-        )
+        KV = M + R + L + U
         attention_mask = torch.zeros(Q, KV).to(
             dtype=torch.bool, device=utterance.device
         )
-        # Disallow attention between the summary vector and the memory bank
+        # disallow attention between the summary vector and the memory bank
         attention_mask[-1, : memory.size(0)] = True
         (
             output_right_context_utterance,
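The mask built here is a (Q, KV) boolean matrix in which only the last query row (the summary vector) is restricted: it may not attend to the M memory positions. A minimal sketch with illustrative sizes:

import torch

R, U, S, M, L = 2, 8, 1, 3, 4    # illustrative sizes
Q = R + U + S                    # queries: [right_context, utterance, summary]
KV = M + R + L + U               # keys:    [memory, right_context, left_context, utterance]

attention_mask = torch.zeros(Q, KV, dtype=torch.bool)
attention_mask[-1, :M] = True    # block summary -> memory bank attention

assert attention_mask.sum().item() == M  # only the last row has masked entries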
@@ -601,12 +593,11 @@ class EmformerAttention(nn.Module):
             left_context_key=left_context_key,
             left_context_val=left_context_val,
         )
-        right_context_end_idx = memory.size(0) + right_context.size(0)
         return (
             output_right_context_utterance,
             output_memory,
-            key[right_context_end_idx:],
-            value[right_context_end_idx:],
+            key[M + R :],
+            value[M + R :],
         )
 
 
@@ -656,6 +647,7 @@ class EmformerLayer(nn.Module):
         self.attention = EmformerAttention(
             embed_dim=d_model,
             nhead=nhead,
+            dropout=dropout,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
         )
@@ -756,9 +748,9 @@ class EmformerLayer(nn.Module):
         layer_norm_input = self.layer_norm_input(
             torch.cat([right_context, utterance])
         )
-        right_context_end_idx = right_context.size(0)
-        layer_norm_utterance = layer_norm_input[right_context_end_idx:]
-        layer_norm_right_context = layer_norm_input[:right_context_end_idx]
+        R = right_context.size(0)
+        layer_norm_utterance = layer_norm_input[R:]
+        layer_norm_right_context = layer_norm_input[:R]
         return layer_norm_utterance, layer_norm_right_context
 
     def _apply_post_attention_ffn_layer_norm(
@@ -768,18 +760,18 @@ class EmformerLayer(nn.Module):
         right_context: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Apply feed forward and layer normalization after attention."""
-        # Apply residual connection between input and attention output.
+        # apply residual connection between input and attention output.
         result = self.dropout(output_right_context_utterance) + torch.cat(
             [right_context, utterance]
         )
-        # Apply feedforward module and residual connection.
+        # apply feedforward module and residual connection.
         result = self.pos_ff(result) + result
-        # Apply layer normalization for output.
+        # apply layer normalization for output.
         result = self.layer_norm_output(result)
 
-        right_context_end_idx = right_context.size(0)
-        output_utterance = result[right_context_end_idx:]
-        output_right_context = result[:right_context_end_idx]
+        R = right_context.size(0)
+        output_utterance = result[R:]
+        output_right_context = result[:R]
         return output_utterance, output_right_context
 
     def _apply_attention_forward(
@@ -796,7 +788,6 @@ class EmformerLayer(nn.Module):
             raise ValueError(
                 "attention_mask must be not None in non-infer mode. "
             )
-
         if self.use_memory:
             summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
                 2, 0, 1
@@ -851,8 +842,10 @@ class EmformerLayer(nn.Module):
             summary = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
-        # pos_emb is of shape [PE, D], PE = L + 2 * U - 1,
-        # the relative distance j - i of key(j) and query(i) is in range of [-(L + U - 1), (U - 1)]  # noqa
+        # pos_emb is of shape [PE, D], where PE = L + 2 * U - 1,
+        # for query of [utterance] (i), key-value [left_context, utterance] (j),
+        # the max relative distance i - j is L + U - 1
+        # the min relative distance i - j is -(U - 1)
         L = left_context_key.size(0)  # L <= left_context_length
         U = utterance.size(0)
         PE = L + 2 * U - 1
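The comment pins down why PE = L + 2 * U - 1 embeddings are needed: with the query index i ranging over the utterance and the key index j over [left_context, utterance], the offset i - j spans every integer from -(U - 1) up to L + U - 1. A quick check with illustrative sizes:

L, U = 4, 8  # illustrative left-context and utterance lengths
offsets = {
    i - j
    for i in range(U)        # query position inside the utterance
    for j in range(-L, U)    # key position inside [left_context, utterance]
}
assert min(offsets) == -(U - 1)
assert max(offsets) == L + U - 1
assert len(offsets) == L + 2 * U - 1  # number of distinct relative positions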
@@ -916,8 +909,8 @@ class EmformerLayer(nn.Module):
            Attention mask for underlying attention module,
            with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For training mode, P = 2*U-1.
+            Position encoding embedding, with shape (PE, D),
+            where PE = 2 * U - 1.
 
        Returns:
          A tuple containing 3 tensors:
@@ -987,8 +980,8 @@ class EmformerLayer(nn.Module):
            List of tensors representing layer internal state generated in
            preceding computation. (default=None)
          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For infer mode, PE = L+2*U-1.
+            Position encoding embedding, with shape (PE, D),
+            where PE = L + 2 * U - 1.
 
        Returns:
          (Tensor, Tensor, List[torch.Tensor], Tensor):
@@ -1073,7 +1066,6 @@ class EmformerEncoder(nn.Module):
         left_context_length: int = 0,
         right_context_length: int = 0,
         max_memory_size: int = 0,
-        weight_init_scale_strategy: str = "depthwise",
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
     ):
@@ -1104,6 +1096,8 @@ class EmformerEncoder(nn.Module):
             ]
         )
 
+        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+
         self.left_context_length = left_context_length
         self.right_context_length = right_context_length
         self.chunk_length = chunk_length
@@ -1246,10 +1240,7 @@ class EmformerEncoder(nn.Module):
         return attention_mask
 
     def forward(
-        self,
-        x: torch.Tensor,
-        lengths: torch.Tensor,
-        pos_emb: torch.Tensor,
+        self, x: torch.Tensor, lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward pass for training and non-streaming inference.
 
@@ -1265,9 +1256,6 @@ class EmformerEncoder(nn.Module):
            With shape (B,) and i-th element representing number of valid
            utterance frames for i-th batch element in x, which contains the
            right_context at the end.
-          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For training mode, P = 2*U-1.
 
        Returns:
          A tuple of 2 tensors:
@@ -1275,8 +1263,11 @@ class EmformerEncoder(nn.Module):
          - output_lengths, with shape (B,), without containing the
            right_context at the end.
        """
+        U = x.size(0) - self.right_context_length
+        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
+
         right_context = self._gen_right_context(x)
-        utterance = x[: x.size(0) - self.right_context_length]
+        utterance = x[:U]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)
         memory = (
@@ -1286,6 +1277,7 @@ class EmformerEncoder(nn.Module):
             if self.use_memory
             else torch.empty(0).to(dtype=x.dtype, device=x.device)
         )
+
         output = utterance
         for layer in self.emformer_layers:
             output, right_context, memory = layer(
@@ -1304,7 +1296,6 @@ class EmformerEncoder(nn.Module):
         self,
         x: torch.Tensor,
         lengths: torch.Tensor,
-        pos_emb: torch.Tensor,
         states: Optional[List[List[torch.Tensor]]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
         """Forward pass for streaming inference.
 
@@ -1325,9 +1316,6 @@ class EmformerEncoder(nn.Module):
            Cached states from preceding chunk's computation, where each
            element (List[torch.Tensor]) corresponding to each emformer layer.
            (default: None)
-          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For infer mode, PE = L+2*U-1.
 
        Returns:
          (Tensor, Tensor, List[List[torch.Tensor]]):
@@ -1341,9 +1329,12 @@ class EmformerEncoder(nn.Module):
                f"expected size of {self.chunk_length + self.right_context_length} "
                f"for dimension 1 of x, but got {x.size(1)}."
            )
-        right_context_start_idx = x.size(0) - self.right_context_length
-        right_context = x[right_context_start_idx:]
-        utterance = x[:right_context_start_idx]
+        pos_len = self.chunk_length + self.left_context_length
+        neg_len = self.chunk_length
+        x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
+
+        right_context = x[self.chunk_length :]
+        utterance = x[: self.chunk_length]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         memory = (
             self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
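In streaming mode the incoming x holds one chunk followed by its right context, and the positional encoding is now requested with pos_len = chunk_length + left_context_length and neg_len = chunk_length. Assuming encoder_pos returns pos_len + neg_len - 1 embedding vectors (consistent with PE = 2 * U - 1 in training and PE = L + 2 * U - 1 here), the sizes line up as follows:

# illustrative sizes for one streaming chunk, not taken from a real run
chunk_length, left_context_length, right_context_length = 8, 4, 2

T_in = chunk_length + right_context_length  # frames passed to EmformerEncoder.infer()

pos_len = chunk_length + left_context_length
neg_len = chunk_length
PE = pos_len + neg_len - 1                  # assumed length of pos_emb
assert PE == left_context_length + 2 * chunk_length - 1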
@@ -1383,7 +1374,6 @@ class Emformer(EncoderInterface):
         left_context_length: int = 0,
         right_context_length: int = 0,
         max_memory_size: int = 0,
-        weight_init_scale_strategy: str = "depthwise",
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
     ):
@@ -1416,8 +1406,6 @@ class Emformer(EncoderInterface):
         else:
             self.encoder_embed = Conv2dSubsampling(num_features, d_model)
 
-        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
-
         self.encoder = EmformerEncoder(
             chunk_length // 4,
             d_model,
@@ -1429,7 +1417,6 @@ class Emformer(EncoderInterface):
             left_context_length=left_context_length // 4,
             right_context_length=right_context_length // 4,
             max_memory_size=max_memory_size,
-            weight_init_scale_strategy=weight_init_scale_strategy,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
         )
@@ -1465,10 +1452,6 @@ class Emformer(EncoderInterface):
            right_context at the end.
        """
         x = self.encoder_embed(x)
-
-        # TODO: The length computation in the encoder class should be moved here.  # noqa
-        U = x.size(1) - self.right_context_length // 4
-        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         # Caution: We assume the subsampling factor is 4!
@@ -1477,7 +1460,7 @@ class Emformer(EncoderInterface):
         x_lens = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(0) == x_lens.max().item()
 
-        output, output_lengths = self.encoder(x, x_lens, pos_emb)  # (T, N, C)
+        output, output_lengths = self.encoder(x, x_lens)  # (T, N, C)
 
         logits = self.encoder_output_layer(output)
         logits = logits.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
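The length formula above matches the subsampling factor of 4 noted in the comment (two stride-2 convolutions in Conv2dSubsampling). A worked example:

import torch

x_lens = torch.tensor([100, 77, 16])      # numbers of input frames
out_lens = ((x_lens - 1) // 2 - 1) // 2   # same formula as in the diff
print(out_lens)                           # tensor([24, 18,  3])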
@@ -1518,12 +1501,6 @@ class Emformer(EncoderInterface):
            - updated states from current chunk's computation.
        """
         x = self.encoder_embed(x)
-
-        # TODO: The length computation in the encoder class should be moved here.  # noqa
-        pos_len = self.chunk_length // 4 + self.left_context_length // 4
-        neg_len = self.chunk_length // 4
-        x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
-
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         # Caution: We assume the subsampling factor is 4!
@@ -1533,7 +1510,7 @@ class Emformer(EncoderInterface):
         assert x.size(0) == x_lens.max().item()
 
         output, output_lengths, output_states = self.encoder.infer(
-            x, x_lens, pos_emb, states
+            x, x_lens, states
         )  # (T, N, C)
 
         logits = self.encoder_output_layer(output)
File: test_emformer.py

@@ -199,7 +199,6 @@ def test_emformer_encoder_forward():
     chunk_length = 4
     right_context_length = 2
     left_context_length = 2
-    left_context_length = 2
     num_chunks = 3
     U = num_chunks * chunk_length
 
@@ -223,10 +222,8 @@ def test_emformer_encoder_forward():
         x = torch.randn(U + right_context_length, B, D)
         lengths = torch.randint(1, U + right_context_length + 1, (B,))
         lengths[0] = U + right_context_length
-        PE = 2 * U - 1
-        pos_emb = torch.randn(PE, D)
 
-        output, output_lengths = encoder(x, lengths, pos_emb)
+        output, output_lengths = encoder(x, lengths)
         assert output.shape == (U, B, D)
         assert torch.equal(
             output_lengths, torch.clamp(lengths - right_context_length, min=0)
@@ -266,11 +263,7 @@ def test_emformer_encoder_infer():
                 1, chunk_length + right_context_length + 1, (B,)
             )
             lengths[0] = chunk_length + right_context_length
-            PE = left_context_length + 2 * chunk_length - 1
-            pos_emb = torch.randn(PE, D)
-            output, output_lengths, states = encoder.infer(
-                x, lengths, pos_emb, states
-            )
+            output, output_lengths, states = encoder.infer(x, lengths, states)
             assert output.shape == (chunk_length, B, D)
             assert torch.equal(
                 output_lengths,
@@ -383,6 +376,7 @@ def test_emformer_infer():
 
 
 def test_emformer_attention_forward_infer_consistency():
+    # TODO: delete
     from emformer import EmformerEncoder
 
     chunk_length = 4
@@ -474,7 +468,7 @@ def test_emformer_layer_forward_infer_consistency():
     chunk_length = 4
     num_chunks = 3
     U = chunk_length * num_chunks
-    L, R = 1, 2
+    left_context_length, right_context_length = 1, 2
     D = 256
     num_encoder_layers = 1
     memory_sizes = [0, 3]
@@ -485,18 +479,22 @@ def test_emformer_layer_forward_infer_consistency():
             d_model=D,
             dim_feedforward=1024,
             num_encoder_layers=num_encoder_layers,
-            left_context_length=L,
-            right_context_length=R,
+            left_context_length=left_context_length,
+            right_context_length=right_context_length,
             max_memory_size=M,
             dropout=0.1,
         )
         encoder.eval()
         encoder_layer = encoder.emformer_layers[0]
+        encoder_pos = encoder.encoder_pos
 
-        x = torch.randn(U + R, 1, D)
+        x = torch.randn(U + right_context_length, 1, D)
 
+        # training mode with full utterance
+        x_forward, pos_emb = encoder_pos(x, U, U)
         lengths = torch.tensor([U])
-        right_context = encoder._gen_right_context(x)
-        utterance = x[: x.size(0) - R]
+        right_context = encoder._gen_right_context(x_forward)
+        utterance = x_forward[:U]
         attention_mask = encoder._gen_attention_mask(utterance)
         memory = (
             encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
@@ -515,15 +513,20 @@ def test_emformer_layer_forward_infer_consistency():
             right_context,
             memory,
             attention_mask,
+            pos_emb,
         )
 
         state = None
         for chunk_idx in range(num_chunks):
             start_idx = chunk_idx * chunk_length
             end_idx = start_idx + chunk_length
-            chunk = x[start_idx:end_idx]
-            chunk_right_context = x[end_idx : end_idx + R]  # noqa
-            chunk_length = torch.tensor([chunk_length])
+            cur_x, pos_emb = encoder_pos(
+                x[start_idx : end_idx + right_context_length],
+                pos_len=chunk_length + left_context_length,
+                neg_len=chunk_length,
+            )
+            chunk = cur_x[:chunk_length]
+            chunk_right_context = cur_x[chunk_length:]
             chunk_memory = (
                 encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(2, 0, 1)
                 if encoder.use_memory
@@ -536,9 +539,10 @@ def test_emformer_layer_forward_infer_consistency():
                 state,
             ) = encoder_layer.infer(
                 chunk,
-                chunk_length,
+                torch.tensor([chunk_length]),
                 chunk_right_context,
                 chunk_memory,
+                pos_emb,
                 state,
             )
             forward_output_chunk = forward_output_utterance[start_idx:end_idx]
@@ -551,7 +555,7 @@ def test_emformer_layer_forward_infer_consistency():
 
 
 def test_emformer_encoder_forward_infer_consistency():
-    from emformer import EmformerEncoder, RelPositionalEncoding
+    from emformer import EmformerEncoder
 
     chunk_length = 4
     num_chunks = 3
@@ -573,28 +577,22 @@ def test_emformer_encoder_forward_infer_consistency():
             dropout=0.1,
         )
         encoder.eval()
-        encoder_pos = RelPositionalEncoding(D, dropout_rate=0)
 
         x = torch.randn(U + right_context_length, 1, D)
         lengths = torch.tensor([U + right_context_length])
-        _, pos_emb = encoder_pos(x, U, U)
 
-        forward_output, forward_output_lengths = encoder(x, lengths, pos_emb)
+        # training mode with full utterance
+        forward_output, forward_output_lengths = encoder(x, lengths)
 
+        # streaming inference mode with individual chunks
         states = None
-        _, pos_emb = encoder_pos(
-            x, chunk_length + left_context_length, chunk_length
-        )
         for chunk_idx in range(num_chunks):
             start_idx = chunk_idx * chunk_length
             end_idx = start_idx + chunk_length
             chunk = x[start_idx : end_idx + right_context_length]  # noqa
             chunk_length = torch.tensor([chunk_length])
             infer_output_chunk, infer_output_lengths, states = encoder.infer(
-                chunk,
-                chunk_length,
-                pos_emb,
-                states,
+                chunk, chunk_length, states
             )
             forward_output_chunk = forward_output[start_idx:end_idx]
             assert torch.allclose(
@@ -615,7 +613,7 @@ def test_emformer_infer_batch_single_consistency():
     chunk_length = 8
     num_chunks = 3
     U = num_chunks * chunk_length
-    L, R = 128, 4
+    left_context_length, right_context_length = 128, 4
     B, D = 2, 256
     num_encoder_layers = 2
     for use_memory in [True, False]:
@@ -630,8 +628,8 @@ def test_emformer_infer_batch_single_consistency():
             subsampling_factor=4,
             d_model=D,
             num_encoder_layers=num_encoder_layers,
-            left_context_length=L,
-            right_context_length=R,
+            left_context_length=left_context_length,
+            right_context_length=right_context_length,
             max_memory_size=M,
             vgg_frontend=False,
         )
@@ -689,20 +687,25 @@ def test_emformer_infer_batch_single_consistency():
                 ],
             )
 
-        x = torch.randn(B, U + R + 3, num_features)
+        x = torch.randn(B, U + right_context_length + 3, num_features)
 
+        # batch-wise inference
         batch_logits = []
         batch_states = []
         states = None
         for chunk_idx in range(num_chunks):
             start_idx = chunk_idx * chunk_length
             end_idx = start_idx + chunk_length
-            chunk = x[:, start_idx : end_idx + R + 3]  # noqa
-            lengths = torch.tensor([chunk_length + R + 3]).expand(B)
+            chunk = x[:, start_idx : end_idx + right_context_length + 3]  # noqa
+            lengths = torch.tensor(
+                [chunk_length + right_context_length + 3]
+            ).expand(B)
             logits, output_lengths, states = model.infer(chunk, lengths, states)
             batch_logits.append(logits)
             batch_states.append(save_states(states))
         batch_logits = torch.cat(batch_logits, dim=1)
 
+        # single-wise inference
         single_logits = []
         for sample_idx in range(B):
             sample = x[sample_idx : sample_idx + 1]  # noqa
@@ -711,17 +714,21 @@ def test_emformer_infer_batch_single_consistency():
             for chunk_idx in range(num_chunks):
                 start_idx = chunk_idx * chunk_length
                 end_idx = start_idx + chunk_length
-                chunk = sample[:, start_idx : end_idx + R + 3]  # noqa
-                lengths = torch.tensor([chunk_length + R + 3])
+                chunk = sample[
+                    :, start_idx : end_idx + right_context_length + 3
+                ]
+                lengths = torch.tensor(
+                    [chunk_length + right_context_length + 3]
+                )
                 logits, output_lengths, states = model.infer(
                     chunk, lengths, states
                 )
                 chunk_logits.append(logits)
 
                 assert_states_equal(batch_states[chunk_idx], states, sample_idx)
 
             chunk_logits = torch.cat(chunk_logits, dim=1)
             single_logits.append(chunk_logits)
 
         single_logits = torch.cat(single_logits, dim=0)
 
         assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0)
@@ -734,7 +741,7 @@ def test_emformer_infer_states_stack():
     output_dim = 1000
     chunk_length = 8
     U = chunk_length
-    L, R = 128, 4
+    left_context_length, right_context_length = 128, 4
     B, D = 2, 256
     num_encoder_layers = 2
     for use_memory in [True, False]:
@@ -749,14 +756,14 @@ def test_emformer_infer_states_stack():
             subsampling_factor=4,
             d_model=D,
             num_encoder_layers=num_encoder_layers,
-            left_context_length=L,
-            right_context_length=R,
+            left_context_length=left_context_length,
+            right_context_length=right_context_length,
             max_memory_size=M,
             vgg_frontend=False,
         )
 
-        x = torch.randn(B, U + R + 3, num_features)
-        x_lens = torch.full((B,), U + R + 3)
+        x = torch.randn(B, U + right_context_length + 3, num_features)
+        x_lens = torch.full((B,), U + right_context_length + 3)
         logits, output_lengths, states = model.infer(
             x,
             x_lens,
@@ -790,8 +797,8 @@ if __name__ == "__main__":
     test_emformer_forward()
     test_emformer_infer()
     # test_emformer_attention_forward_infer_consistency()
-    # test_emformer_layer_forward_infer_consistency()
+    test_emformer_layer_forward_infer_consistency()
     test_emformer_encoder_forward_infer_consistency()
-    # test_emformer_infer_batch_single_consistency()
-    # test_emformer_infer_states_stack()
+    test_emformer_infer_batch_single_consistency()
+    test_emformer_infer_states_stack()
     test_rel_positional_encoding()