Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-04 14:44:18 +00:00)

Commit e3a29b17f3 ("minor refactor of emformer codes"), parent aff7c4ee3c.
@ -40,24 +40,6 @@ def _get_activation_module(activation: str) -> nn.Module:
    raise ValueError(f"Unsupported activation {activation}")


def _get_weight_init_gains(
    weight_init_scale_strategy: Optional[str], num_layers: int
) -> List[Optional[float]]:
    if weight_init_scale_strategy is None:
        return [None for _ in range(num_layers)]
    elif weight_init_scale_strategy == "depthwise":
        return [
            1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)
        ]
    elif weight_init_scale_strategy == "constant":
        return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
    else:
        raise ValueError(
            f"Unsupported weight_init_scale_strategy value"
            f"{weight_init_scale_strategy}"
        )


def _gen_attention_mask_block(
    col_widths: List[int],
    col_mask: List[bool],
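The hunk above removes the `_get_weight_init_gains` helper. For reference, a minimal sketch of what it returned under the "depthwise" strategy (the layer count here is illustrative):

import math

# Per-layer weight-init gains for the "depthwise" strategy:
# 1 / sqrt(layer_idx + 1) for each layer.
num_layers = 4
gains = [1.0 / math.sqrt(i + 1) for i in range(num_layers)]
# gains == [1.0, 0.7071..., 0.5773..., 0.5]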
@ -154,6 +136,8 @@ class EmformerAttention(nn.Module):
        Embedding dimension.
      nhead (int):
        Number of attention heads in each Emformer layer.
      dropout (float):
        A Dropout layer on attn_output_weights. (Default: 0.0)
      tanh_on_mem (bool, optional):
        If ``True``, applies tanh to memory elements. (Default: ``False``)
      negative_inf (float, optional):
@ -164,6 +148,7 @@ class EmformerAttention(nn.Module):
        self,
        embed_dim: int,
        nhead: int,
        dropout: float = 0.0,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
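The parameters documented above map directly onto the constructor. A minimal usage sketch (sizes are illustrative, and the import assumes the local `emformer` module used by the tests below):

from emformer import EmformerAttention

attention = EmformerAttention(
    embed_dim=512,  # must be a multiple of nhead
    nhead=8,
    dropout=0.1,
    tanh_on_mem=False,
    negative_inf=-1e8,
)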
@ -173,13 +158,14 @@ class EmformerAttention(nn.Module):
            raise ValueError(
                f"embed_dim ({embed_dim}) is not a multiple of nhead ({nhead})."
            )

        self.embed_dim = embed_dim
        self.nhead = nhead
        self.tanh_on_mem = tanh_on_mem
        self.negative_inf = negative_inf
        self.head_dim = embed_dim // nhead

        self.dropout = dropout

        self.scaling = self.head_dim ** -0.5

        self.emb_to_key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
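A quick numeric illustration of the derived quantities above (hypothetical sizes):

embed_dim, nhead = 512, 8
head_dim = embed_dim // nhead  # 64
scaling = head_dim ** -0.5     # 0.125, multiplied into the attention scores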
@ -262,6 +248,9 @@ class EmformerAttention(nn.Module):
            attention_weights_float, dim=-1
        ).type_as(attention_weights)

        attention_probs = nn.functional.dropout(
            attention_probs, p=self.dropout, training=self.training
        )
        return attention_probs

    def _rel_shift(self, x: torch.Tensor) -> torch.Tensor:
@ -311,12 +300,12 @@ class EmformerAttention(nn.Module):
        KV: length of attention key and value.

        1) Concat right_context, utterance, summary,
        and compute query tensor with length Q = R + U + S.
        and compute query with length Q = R + U + S.
        2) Concat memory, right_context, utterance,
        and compute key, value tensors with length KV = M + R + U;
        optionally with left_context_key and left_context_val (inference mode),
        and compute key, value with length KV = M + R + U;
        also with left_context_key and left_context_val for inference mode,
        then KV = M + R + L + U.
        3) Compute entire attention scores with query, key, and value,
        3) Compute entire attention scores with above query, key, and value,
        then apply attention_mask to get underlying chunk-wise attention scores.

        Args:
@ -335,14 +324,14 @@ class EmformerAttention(nn.Module):
            Attention mask for underlying attention, with shape (Q, KV).
          pos_emb (torch.Tensor):
            Position encoding embedding, with shape (PE, D).
            For training mode, PE = 2*U-1;
            For infer mode, PE = L+2*U-1.
            For training mode, PE = 2 * U - 1;
            For inference mode, PE = L + 2 * U - 1.
          left_context_key (torch.Tensor, optional):
            Cached attention key of left context from preceding computation,
            with shape (L, B, D).
            with shape (L, B, D). It is used in inference mode.
          left_context_val (torch.Tensor, optional):
            Cached attention value of left context from preceding computation,
            with shape (L, B, D).
            with shape (L, B, D). It is used in inference mode.

        Returns:
          A tuple containing 4 tensors:
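As a worked illustration of the length bookkeeping in steps 1) to 3) above (all sizes hypothetical):

# query = [right_context, utterance, summary]
# key/value = [memory, right_context, (left_context,) utterance]
R, U, S, M, L = 2, 8, 1, 4, 8
Q = R + U + S             # 11 query frames
KV_train = M + R + U      # 14 key/value frames in training mode
KV_infer = M + R + L + U  # 22 key/value frames in inference mode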
@ -355,23 +344,21 @@ class EmformerAttention(nn.Module):
        R = right_context.size(0)
        M = memory.size(0)

        # Compute query with [right context, utterance, summary].
        # compute query with [right context, utterance, summary].
        query = self.emb_to_query(
            torch.cat([right_context, utterance, summary])
        )
        # Compute key and value with [mems, right context, utterance].
        # compute key and value with [mems, right context, utterance].
        key, value = self.emb_to_key_value(
            torch.cat([memory, right_context, utterance])
        ).chunk(chunks=2, dim=2)

        if left_context_key is not None and left_context_val is not None:
            # This is for inference mode. Now compute key and value with
            # compute key and value with
            # [mems, right context, left context, utterance]
            key = torch.cat(
                [key[: M + R], left_context_key, key[M + R :]]  # noqa
            )
            key = torch.cat([key[: M + R], left_context_key, key[M + R :]])
            value = torch.cat(
                [value[: M + R], left_context_val, value[M + R :]]  # noqa
                [value[: M + R], left_context_val, value[M + R :]]
            )
        Q = query.size(0)
        KV = key.size(0)
@ -381,12 +368,14 @@ class EmformerAttention(nn.Module):
            .view(KV, B * self.nhead, self.head_dim)
            .transpose(0, 1)
            for tensor in [key, value]
        ]  # (B * nhead, KV, head_dim)
        ]  # both of shape (B * nhead, KV, head_dim)
        reshaped_query = query.contiguous().view(
            Q, B, self.nhead, self.head_dim
        )

        # compute attention matrix ac
        # compute attention score
        # first compute attention matrix a and matrix c
        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3  # noqa
        query_with_bais_u = (
            (reshaped_query + self.pos_bias_u)
            .view(Q, B * self.nhead, self.head_dim)
@ -396,7 +385,9 @@ class EmformerAttention(nn.Module):
            query_with_bais_u, reshaped_key.transpose(1, 2)
        )  # (B * nhead, Q, KV)

        # compute attention matrix bd
        # second, compute attention matrix b and matrix d
        # relative positional encoding is applied on the part of attention
        # between query: [utterance] -> key, value: [left_context, utterance]
        utterance_with_bais_v = (
            reshaped_query[R : R + U] + self.pos_bias_v
        ).permute(1, 2, 0, 3)
@ -416,10 +407,10 @@ class EmformerAttention(nn.Module):
        matrix_bd_utterance = torch.matmul(
            utterance_with_bais_v, pos_emb.transpose(-2, -1)
        )  # (B, nhead, U, PE)
        # rel-shift
        matrix_bd_utterance = self._rel_shift(
            matrix_bd_utterance
        )  # (B, nhead, U, U or L + U)
        # rel-shift operation
        matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
        # (B, nhead, U, U) for training mode;
        # (B, nhead, U, L + U) for inference mode.
        matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
            B * self.nhead, U, -1
        )
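The rel-shift above realigns the (B, nhead, U, PE) scores, which are indexed by relative distance, into per-key scores of shape (B, nhead, U, L + U). Below is a gather-based sketch of the same idea, assuming the conformer-style ordering in which index 0 along the last dimension holds the largest positive relative distance; it is an illustration of the technique, not the file's `_rel_shift` implementation (which uses a padding-and-reshape trick):

import torch


def rel_shift_sketch(scores: torch.Tensor, L: int) -> torch.Tensor:
    """scores: (B, nhead, U, PE) with PE = L + 2 * U - 1."""
    B, H, U, PE = scores.shape
    assert PE == L + 2 * U - 1
    # For query i and key j over [left_context, utterance] (length L + U),
    # the relative distance is L + i - j, stored at index (U - 1) - i + j.
    idx = torch.arange(L + U).unsqueeze(0) - torch.arange(U).unsqueeze(1) + (U - 1)
    idx = idx.expand(B, H, U, L + U)  # values lie in [0, PE - 1]
    return scores.gather(dim=-1, index=idx)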
@ -428,25 +419,25 @@ class EmformerAttention(nn.Module):

        attention_weights = (matrix_ac + matrix_bd) * self.scaling

        # Compute padding mask
        # compute padding mask
        if B == 1:
            padding_mask = None
        else:
            padding_mask = make_pad_mask(KV - U + lengths)

        # Compute attention probabilities.
        # compute attention probabilities
        attention_probs = self._gen_attention_probs(
            attention_weights, attention_mask, padding_mask
        )

        # Compute attention.
        # compute attention outputs
        attention = torch.bmm(attention_probs, reshaped_value)
        assert attention.shape == (B * self.nhead, Q, self.head_dim)
        attention = (
            attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim)
        )

        # Apply output projection.
        # apply output projection
        outputs = self.out_proj(attention)

        output_right_context_utterance = outputs[: R + U]
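One point in the hunk above is worth unpacking: the padding mask is built from KV - U + lengths because the first KV - U key positions (memory, right context, and any cached left context) are always valid, so each sequence's number of valid keys is (KV - U) plus its utterance length. A hedged sketch; the helper below only mirrors the convention of `make_pad_mask` (True at padded positions), and the numbers are hypothetical:

import torch


def make_pad_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    # True marks padded (invalid) positions, False marks valid ones.
    max_len = int(lengths.max())
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)


KV, U = 12, 4
lengths = torch.tensor([4, 2])     # valid utterance frames per sequence
key_valid_lens = KV - U + lengths  # tensor([12, 10])
padding_mask = make_pad_mask_sketch(key_valid_lens)  # shape (2, 12), bool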
@ -487,7 +478,7 @@ class EmformerAttention(nn.Module):
          right_context (torch.Tensor):
            Right context frames, with shape (R, B, D).
          summary (torch.Tensor):
            Summary elements, with shape (S, B, D).
            Summary elements with shape (S, B, D) or an empty tensor.
          memory (torch.Tensor):
            Memory elements, with shape (M, B, D).
          attention_mask (torch.Tensor):
@ -495,7 +486,7 @@ class EmformerAttention(nn.Module):
            with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
          pos_emb (torch.Tensor):
            Position encoding embedding, with shape (PE, D).
            For training mode, P = 2*U-1.
            where PE = 2 * U - 1.

        Returns:
          A tuple containing 2 tensors:
@ -549,7 +540,7 @@ class EmformerAttention(nn.Module):
          right_context (torch.Tensor):
            Right context frames, with shape (R, B, D).
          summary (torch.Tensor):
            Summary element, with shape (1, B, D), or empty.
            Summary element with shape (1, B, D), or an empty tensor.
          memory (torch.Tensor):
            Memory elements, with shape (M, B, D).
          left_context_key (torch.Tensor):
@ -571,19 +562,20 @@ class EmformerAttention(nn.Module):
          - attention value of left context and utterance, which would be
            cached for next computation, with shape (L + U, B, D).
        """
        U = utterance.size(0)
        R = right_context.size(0)
        L = left_context_key.size(0)
        S = summary.size(0)
        M = memory.size(0)

        # query: [right context, utterance, summary]
        Q = right_context.size(0) + utterance.size(0) + summary.size(0)
        Q = R + U + S
        # key, value: [memory, right context, left context, utterance]
        KV = (
            memory.size(0)
            + right_context.size(0)  # noqa
            + left_context_key.size(0)  # noqa
            + utterance.size(0)  # noqa
        )
        KV = M + R + L + U
        attention_mask = torch.zeros(Q, KV).to(
            dtype=torch.bool, device=utterance.device
        )
        # Disallow attention between the summary vector and the memory bank
        # disallow attention between the summary vector and the memory bank
        attention_mask[-1, : memory.size(0)] = True
        (
            output_right_context_utterance,
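A minimal sketch of the mask built in the infer path above (sizes hypothetical): a (Q, KV) boolean mask in which only the summary query row is barred from the memory columns.

import torch

R, U, S, M, L = 2, 4, 1, 3, 4
Q, KV = R + U + S, M + R + L + U
attention_mask = torch.zeros(Q, KV, dtype=torch.bool)
attention_mask[-1, :M] = True  # the summary query cannot attend to the memory bank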
@ -601,12 +593,11 @@ class EmformerAttention(nn.Module):
            left_context_key=left_context_key,
            left_context_val=left_context_val,
        )
        right_context_end_idx = memory.size(0) + right_context.size(0)
        return (
            output_right_context_utterance,
            output_memory,
            key[right_context_end_idx:],
            value[right_context_end_idx:],
            key[M + R :],
            value[M + R :],
        )


@ -656,6 +647,7 @@ class EmformerLayer(nn.Module):
        self.attention = EmformerAttention(
            embed_dim=d_model,
            nhead=nhead,
            dropout=dropout,
            tanh_on_mem=tanh_on_mem,
            negative_inf=negative_inf,
        )
@ -756,9 +748,9 @@ class EmformerLayer(nn.Module):
        layer_norm_input = self.layer_norm_input(
            torch.cat([right_context, utterance])
        )
        right_context_end_idx = right_context.size(0)
        layer_norm_utterance = layer_norm_input[right_context_end_idx:]
        layer_norm_right_context = layer_norm_input[:right_context_end_idx]
        R = right_context.size(0)
        layer_norm_utterance = layer_norm_input[R:]
        layer_norm_right_context = layer_norm_input[:R]
        return layer_norm_utterance, layer_norm_right_context

    def _apply_post_attention_ffn_layer_norm(
@ -768,18 +760,18 @@ class EmformerLayer(nn.Module):
        right_context: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply feed forward and layer normalization after attention."""
        # Apply residual connection between input and attention output.
        # apply residual connection between input and attention output.
        result = self.dropout(output_right_context_utterance) + torch.cat(
            [right_context, utterance]
        )
        # Apply feedforward module and residual connection.
        # apply feedforward module and residual connection.
        result = self.pos_ff(result) + result
        # Apply layer normalization for output.
        # apply layer normalization for output.
        result = self.layer_norm_output(result)

        right_context_end_idx = right_context.size(0)
        output_utterance = result[right_context_end_idx:]
        output_right_context = result[:right_context_end_idx]
        R = right_context.size(0)
        output_utterance = result[R:]
        output_right_context = result[:R]
        return output_utterance, output_right_context

    def _apply_attention_forward(
@ -796,7 +788,6 @@ class EmformerLayer(nn.Module):
            raise ValueError(
                "attention_mask must not be None in non-infer mode."
            )

        if self.use_memory:
            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
                2, 0, 1
@ -851,8 +842,10 @@ class EmformerLayer(nn.Module):
            summary = torch.empty(0).to(
                dtype=utterance.dtype, device=utterance.device
            )
        # pos_emb is of shape [PE, D], PE = L + 2 * U - 1,
        # the relative distance j - i of key(j) and query(i) is in range of [-(L + U - 1), (U - 1)]  # noqa
        # pos_emb is of shape [PE, D], where PE = L + 2 * U - 1,
        # for query of [utterance] (i), key-value [left_context, utterance] (j),
        # the max relative distance i - j is L + U - 1
        # the min relative distance i - j is -(U - 1)
        L = left_context_key.size(0)  # L <= left_context_length
        U = utterance.size(0)
        PE = L + 2 * U - 1
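A small numeric check of the relative-distance comment above (illustrative values):

L, U = 2, 4
PE = L + 2 * U - 1                        # 9 positional embeddings
max_dist, min_dist = L + U - 1, -(U - 1)  # i - j spans [-3, 5]
assert PE == max_dist - min_dist + 1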
@ -916,8 +909,8 @@ class EmformerLayer(nn.Module):
            Attention mask for underlying attention module,
            with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
          pos_emb (torch.Tensor):
            Position encoding embedding, with shape (PE, D).
            For training mode, PE = 2 * U - 1.
            Position encoding embedding, with shape (PE, D),
            where PE = 2 * U - 1.

        Returns:
          A tuple containing 3 tensors:
@ -987,8 +980,8 @@ class EmformerLayer(nn.Module):
            List of tensors representing layer internal state generated in
            preceding computation. (default=None)
          pos_emb (torch.Tensor):
            Position encoding embedding, with shape (PE, D).
            For infer mode, PE = L + 2 * U - 1.
            Position encoding embedding, with shape (PE, D),
            where PE = L + 2 * U - 1.

        Returns:
          (Tensor, Tensor, List[torch.Tensor], Tensor):
@ -1073,7 +1066,6 @@ class EmformerEncoder(nn.Module):
        left_context_length: int = 0,
        right_context_length: int = 0,
        max_memory_size: int = 0,
        weight_init_scale_strategy: str = "depthwise",
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
@ -1104,6 +1096,8 @@ class EmformerEncoder(nn.Module):
            ]
        )

        self.encoder_pos = RelPositionalEncoding(d_model, dropout)

        self.left_context_length = left_context_length
        self.right_context_length = right_context_length
        self.chunk_length = chunk_length
@ -1246,10 +1240,7 @@ class EmformerEncoder(nn.Module):
        return attention_mask

    def forward(
        self,
        x: torch.Tensor,
        lengths: torch.Tensor,
        pos_emb: torch.Tensor,
        self, x: torch.Tensor, lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for training and non-streaming inference.

@ -1265,9 +1256,6 @@ class EmformerEncoder(nn.Module):
            With shape (B,) and i-th element representing number of valid
            utterance frames for i-th batch element in x, which contains the
            right_context at the end.
          pos_emb (torch.Tensor):
            Position encoding embedding, with shape (PE, D).
            For training mode, PE = 2 * U - 1.

        Returns:
          A tuple of 2 tensors:
@ -1275,8 +1263,11 @@ class EmformerEncoder(nn.Module):
          - output_lengths, with shape (B,), without containing the
            right_context at the end.
        """
        U = x.size(0) - self.right_context_length
        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)

        right_context = self._gen_right_context(x)
        utterance = x[: x.size(0) - self.right_context_length]
        utterance = x[:U]
        output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
        attention_mask = self._gen_attention_mask(utterance)
        memory = (
@ -1286,6 +1277,7 @@ class EmformerEncoder(nn.Module):
            if self.use_memory
            else torch.empty(0).to(dtype=x.dtype, device=x.device)
        )

        output = utterance
        for layer in self.emformer_layers:
            output, right_context, memory = layer(
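The memory bank above comes from `self.init_memory_op` applied to the utterance. A sketch of the usual choice, an average pooling over each chunk as in torchaudio's Emformer; the exact operator used in this file is an assumption, and the sizes are illustrative:

import torch.nn as nn

chunk_length, d_model = 4, 256
init_memory_op = nn.AvgPool1d(
    kernel_size=chunk_length, stride=chunk_length, ceil_mode=True
)
# utterance (U, B, D) is permuted to (B, D, U) for pooling,
# then back to (M, B, D) with M = ceil(U / chunk_length).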
@ -1304,7 +1296,6 @@ class EmformerEncoder(nn.Module):
        self,
        x: torch.Tensor,
        lengths: torch.Tensor,
        pos_emb: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        """Forward pass for streaming inference.
@ -1325,9 +1316,6 @@ class EmformerEncoder(nn.Module):
            Cached states from the preceding chunk's computation, where each
            element (List[torch.Tensor]) corresponds to one emformer layer.
            (default: None)
          pos_emb (torch.Tensor):
            Position encoding embedding, with shape (PE, D).
            For infer mode, PE = L + 2 * U - 1.

        Returns:
          (Tensor, Tensor, List[List[torch.Tensor]]):
@ -1341,9 +1329,12 @@ class EmformerEncoder(nn.Module):
            f"expected size of {self.chunk_length + self.right_context_length} "
            f"for dimension 1 of x, but got {x.size(1)}."
        )
        right_context_start_idx = x.size(0) - self.right_context_length
        right_context = x[right_context_start_idx:]
        utterance = x[:right_context_start_idx]
        pos_len = self.chunk_length + self.left_context_length
        neg_len = self.chunk_length
        x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)

        right_context = x[self.chunk_length :]
        utterance = x[: self.chunk_length]
        output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
        memory = (
            self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
@ -1383,7 +1374,6 @@ class Emformer(EncoderInterface):
        left_context_length: int = 0,
        right_context_length: int = 0,
        max_memory_size: int = 0,
        weight_init_scale_strategy: str = "depthwise",
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
@ -1416,8 +1406,6 @@ class Emformer(EncoderInterface):
        else:
            self.encoder_embed = Conv2dSubsampling(num_features, d_model)

        self.encoder_pos = RelPositionalEncoding(d_model, dropout)

        self.encoder = EmformerEncoder(
            chunk_length // 4,
            d_model,
@ -1429,7 +1417,6 @@ class Emformer(EncoderInterface):
            left_context_length=left_context_length // 4,
            right_context_length=right_context_length // 4,
            max_memory_size=max_memory_size,
            weight_init_scale_strategy=weight_init_scale_strategy,
            tanh_on_mem=tanh_on_mem,
            negative_inf=negative_inf,
        )
@ -1465,10 +1452,6 @@ class Emformer(EncoderInterface):
            right_context at the end.
        """
        x = self.encoder_embed(x)

        # TODO: The length computation in the encoder class should be moved here.  # noqa
        U = x.size(1) - self.right_context_length // 4
        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # Caution: We assume the subsampling factor is 4!
@ -1477,7 +1460,7 @@ class Emformer(EncoderInterface):
        x_lens = ((x_lens - 1) // 2 - 1) // 2
        assert x.size(0) == x_lens.max().item()

        output, output_lengths = self.encoder(x, x_lens, pos_emb)  # (T, N, C)
        output, output_lengths = self.encoder(x, x_lens)  # (T, N, C)

        logits = self.encoder_output_layer(output)
        logits = logits.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
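A worked example of the subsampling-factor-4 length formula above (numbers are illustrative): two stride-2 reductions map T input frames to roughly T // 4 output frames.

x_lens = 100
out_lens = ((x_lens - 1) // 2 - 1) // 2  # ((99 // 2) - 1) // 2 = 24
assert out_lens == 24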
@ -1518,12 +1501,6 @@ class Emformer(EncoderInterface):
          - updated states from current chunk's computation.
        """
        x = self.encoder_embed(x)

        # TODO: The length computation in the encoder class should be moved here.  # noqa
        pos_len = self.chunk_length // 4 + self.left_context_length // 4
        neg_len = self.chunk_length // 4
        x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)

        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # Caution: We assume the subsampling factor is 4!
@ -1533,7 +1510,7 @@ class Emformer(EncoderInterface):
        assert x.size(0) == x_lens.max().item()

        output, output_lengths, output_states = self.encoder.infer(
            x, x_lens, pos_emb, states
            x, x_lens, states
        )  # (T, N, C)

        logits = self.encoder_output_layer(output)
@ -199,7 +199,6 @@ def test_emformer_encoder_forward():
    chunk_length = 4
    right_context_length = 2
    left_context_length = 2
    left_context_length = 2
    num_chunks = 3
    U = num_chunks * chunk_length

@ -223,10 +222,8 @@ def test_emformer_encoder_forward():
        x = torch.randn(U + right_context_length, B, D)
        lengths = torch.randint(1, U + right_context_length + 1, (B,))
        lengths[0] = U + right_context_length
        PE = 2 * U - 1
        pos_emb = torch.randn(PE, D)

        output, output_lengths = encoder(x, lengths, pos_emb)
        output, output_lengths = encoder(x, lengths)
        assert output.shape == (U, B, D)
        assert torch.equal(
            output_lengths, torch.clamp(lengths - right_context_length, min=0)
@ -266,11 +263,7 @@ def test_emformer_encoder_infer():
            1, chunk_length + right_context_length + 1, (B,)
        )
        lengths[0] = chunk_length + right_context_length
        PE = left_context_length + 2 * chunk_length - 1
        pos_emb = torch.randn(PE, D)
        output, output_lengths, states = encoder.infer(
            x, lengths, pos_emb, states
        )
        output, output_lengths, states = encoder.infer(x, lengths, states)
        assert output.shape == (chunk_length, B, D)
        assert torch.equal(
            output_lengths,
@ -383,6 +376,7 @@ def test_emformer_infer():


def test_emformer_attention_forward_infer_consistency():
    # TODO: delete
    from emformer import EmformerEncoder

    chunk_length = 4
@ -474,7 +468,7 @@ def test_emformer_layer_forward_infer_consistency():
    chunk_length = 4
    num_chunks = 3
    U = chunk_length * num_chunks
    L, R = 1, 2
    left_context_length, right_context_length = 1, 2
    D = 256
    num_encoder_layers = 1
    memory_sizes = [0, 3]
@ -485,18 +479,22 @@ def test_emformer_layer_forward_infer_consistency():
            d_model=D,
            dim_feedforward=1024,
            num_encoder_layers=num_encoder_layers,
            left_context_length=L,
            right_context_length=R,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
            dropout=0.1,
        )
        encoder.eval()
        encoder_layer = encoder.emformer_layers[0]
        encoder_pos = encoder.encoder_pos

        x = torch.randn(U + R, 1, D)
        x = torch.randn(U + right_context_length, 1, D)

        # training mode with full utterance
        x_forward, pos_emb = encoder_pos(x, U, U)
        lengths = torch.tensor([U])
        right_context = encoder._gen_right_context(x)
        utterance = x[: x.size(0) - R]
        right_context = encoder._gen_right_context(x_forward)
        utterance = x_forward[:U]
        attention_mask = encoder._gen_attention_mask(utterance)
        memory = (
            encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
@ -515,15 +513,20 @@ def test_emformer_layer_forward_infer_consistency():
            right_context,
            memory,
            attention_mask,
            pos_emb,
        )

        state = None
        for chunk_idx in range(num_chunks):
            start_idx = chunk_idx * chunk_length
            end_idx = start_idx + chunk_length
            chunk = x[start_idx:end_idx]
            chunk_right_context = x[end_idx : end_idx + R]  # noqa
            chunk_length = torch.tensor([chunk_length])
            cur_x, pos_emb = encoder_pos(
                x[start_idx : end_idx + right_context_length],
                pos_len=chunk_length + left_context_length,
                neg_len=chunk_length,
            )
            chunk = cur_x[:chunk_length]
            chunk_right_context = cur_x[chunk_length:]
            chunk_memory = (
                encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(2, 0, 1)
                if encoder.use_memory
@ -536,9 +539,10 @@ def test_emformer_layer_forward_infer_consistency():
                state,
            ) = encoder_layer.infer(
                chunk,
                chunk_length,
                torch.tensor([chunk_length]),
                chunk_right_context,
                chunk_memory,
                pos_emb,
                state,
            )
            forward_output_chunk = forward_output_utterance[start_idx:end_idx]
@ -551,7 +555,7 @@ def test_emformer_layer_forward_infer_consistency():


def test_emformer_encoder_forward_infer_consistency():
    from emformer import EmformerEncoder, RelPositionalEncoding
    from emformer import EmformerEncoder

    chunk_length = 4
    num_chunks = 3
@ -573,28 +577,22 @@ def test_emformer_encoder_forward_infer_consistency():
        dropout=0.1,
    )
    encoder.eval()
    encoder_pos = RelPositionalEncoding(D, dropout_rate=0)

    x = torch.randn(U + right_context_length, 1, D)
    lengths = torch.tensor([U + right_context_length])
    _, pos_emb = encoder_pos(x, U, U)

    forward_output, forward_output_lengths = encoder(x, lengths, pos_emb)
    # training mode with full utterance
    forward_output, forward_output_lengths = encoder(x, lengths)

    # streaming inference mode with individual chunks
    states = None
    _, pos_emb = encoder_pos(
        x, chunk_length + left_context_length, chunk_length
    )
    for chunk_idx in range(num_chunks):
        start_idx = chunk_idx * chunk_length
        end_idx = start_idx + chunk_length
        chunk = x[start_idx : end_idx + right_context_length]  # noqa
        chunk_length = torch.tensor([chunk_length])
        infer_output_chunk, infer_output_lengths, states = encoder.infer(
            chunk,
            chunk_length,
            pos_emb,
            states,
            chunk, chunk_length, states
        )
        forward_output_chunk = forward_output[start_idx:end_idx]
        assert torch.allclose(
@ -615,7 +613,7 @@ def test_emformer_infer_batch_single_consistency():
    chunk_length = 8
    num_chunks = 3
    U = num_chunks * chunk_length
    L, R = 128, 4
    left_context_length, right_context_length = 128, 4
    B, D = 2, 256
    num_encoder_layers = 2
    for use_memory in [True, False]:
@ -630,8 +628,8 @@ def test_emformer_infer_batch_single_consistency():
            subsampling_factor=4,
            d_model=D,
            num_encoder_layers=num_encoder_layers,
            left_context_length=L,
            right_context_length=R,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
            vgg_frontend=False,
        )
@ -689,20 +687,25 @@ def test_emformer_infer_batch_single_consistency():
            ],
        )

        x = torch.randn(B, U + R + 3, num_features)
        x = torch.randn(B, U + right_context_length + 3, num_features)

        # batch-wise inference
        batch_logits = []
        batch_states = []
        states = None
        for chunk_idx in range(num_chunks):
            start_idx = chunk_idx * chunk_length
            end_idx = start_idx + chunk_length
            chunk = x[:, start_idx : end_idx + R + 3]  # noqa
            lengths = torch.tensor([chunk_length + R + 3]).expand(B)
            chunk = x[:, start_idx : end_idx + right_context_length + 3]  # noqa
            lengths = torch.tensor(
                [chunk_length + right_context_length + 3]
            ).expand(B)
            logits, output_lengths, states = model.infer(chunk, lengths, states)
            batch_logits.append(logits)
            batch_states.append(save_states(states))
        batch_logits = torch.cat(batch_logits, dim=1)

        # single-wise inference
        single_logits = []
        for sample_idx in range(B):
            sample = x[sample_idx : sample_idx + 1]  # noqa
@ -711,17 +714,21 @@ def test_emformer_infer_batch_single_consistency():
            for chunk_idx in range(num_chunks):
                start_idx = chunk_idx * chunk_length
                end_idx = start_idx + chunk_length
                chunk = sample[:, start_idx : end_idx + R + 3]  # noqa
                lengths = torch.tensor([chunk_length + R + 3])
                chunk = sample[
                    :, start_idx : end_idx + right_context_length + 3
                ]
                lengths = torch.tensor(
                    [chunk_length + right_context_length + 3]
                )
                logits, output_lengths, states = model.infer(
                    chunk, lengths, states
                )
                chunk_logits.append(logits)

                assert_states_equal(batch_states[chunk_idx], states, sample_idx)

            chunk_logits = torch.cat(chunk_logits, dim=1)
            single_logits.append(chunk_logits)

        single_logits = torch.cat(single_logits, dim=0)

        assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0)
@ -734,7 +741,7 @@ def test_emformer_infer_states_stack():
    output_dim = 1000
    chunk_length = 8
    U = chunk_length
    L, R = 128, 4
    left_context_length, right_context_length = 128, 4
    B, D = 2, 256
    num_encoder_layers = 2
    for use_memory in [True, False]:
@ -749,14 +756,14 @@ def test_emformer_infer_states_stack():
            subsampling_factor=4,
            d_model=D,
            num_encoder_layers=num_encoder_layers,
            left_context_length=L,
            right_context_length=R,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=M,
            vgg_frontend=False,
        )

        x = torch.randn(B, U + R + 3, num_features)
        x_lens = torch.full((B,), U + R + 3)
        x = torch.randn(B, U + right_context_length + 3, num_features)
        x_lens = torch.full((B,), U + right_context_length + 3)
        logits, output_lengths, states = model.infer(
            x,
            x_lens,
@ -790,8 +797,8 @@ if __name__ == "__main__":
    test_emformer_forward()
    test_emformer_infer()
    # test_emformer_attention_forward_infer_consistency()
    # test_emformer_layer_forward_infer_consistency()
    test_emformer_layer_forward_infer_consistency()
    test_emformer_encoder_forward_infer_consistency()
    # test_emformer_infer_batch_single_consistency()
    # test_emformer_infer_states_stack()
    test_emformer_infer_batch_single_consistency()
    test_emformer_infer_states_stack()
    test_rel_positional_encoding()