Minor refactor of Emformer code

yaozengwei 2022-05-10 11:15:44 +08:00
parent aff7c4ee3c
commit e3a29b17f3
2 changed files with 138 additions and 154 deletions

View File

@ -40,24 +40,6 @@ def _get_activation_module(activation: str) -> nn.Module:
raise ValueError(f"Unsupported activation {activation}")
def _get_weight_init_gains(
weight_init_scale_strategy: Optional[str], num_layers: int
) -> List[Optional[float]]:
if weight_init_scale_strategy is None:
return [None for _ in range(num_layers)]
elif weight_init_scale_strategy == "depthwise":
return [
1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)
]
elif weight_init_scale_strategy == "constant":
return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
else:
raise ValueError(
f"Unsupported weight_init_scale_strategy value"
f"{weight_init_scale_strategy}"
)
def _gen_attention_mask_block(
col_widths: List[int],
col_mask: List[bool],
@ -154,6 +136,8 @@ class EmformerAttention(nn.Module):
Embedding dimension.
nhead (int):
Number of attention heads in each Emformer layer.
dropout (float, optional):
Dropout probability applied to the attention weights. (Default: 0.0)
tanh_on_mem (bool, optional):
If ``True``, applies tanh to memory elements. (Default: ``False``)
negative_inf (float, optional):
@ -164,6 +148,7 @@ class EmformerAttention(nn.Module):
self,
embed_dim: int,
nhead: int,
dropout: float = 0.0,
tanh_on_mem: bool = False,
negative_inf: float = -1e8,
):
@ -173,13 +158,14 @@ class EmformerAttention(nn.Module):
raise ValueError(
f"embed_dim ({embed_dim}) is not a multiple of nhead ({nhead})."
)
self.embed_dim = embed_dim
self.nhead = nhead
self.tanh_on_mem = tanh_on_mem
self.negative_inf = negative_inf
self.head_dim = embed_dim // nhead
self.dropout = dropout
self.scaling = self.head_dim ** -0.5
self.emb_to_key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
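For reference, a minimal construction sketch of the updated attention module, assuming this recipe's emformer.py is importable (as the tests below do) and using illustrative sizes:

from emformer import EmformerAttention  # this recipe's emformer.py

# illustrative sizes; embed_dim must be a multiple of nhead
attention = EmformerAttention(
    embed_dim=256,
    nhead=8,
    dropout=0.1,      # new argument: dropout applied to the attention weights
    tanh_on_mem=True,
)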
@ -262,6 +248,9 @@ class EmformerAttention(nn.Module):
attention_weights_float, dim=-1
).type_as(attention_weights)
attention_probs = nn.functional.dropout(
attention_probs, p=self.dropout, training=self.training
)
return attention_probs
def _rel_shift(self, x: torch.Tensor) -> torch.Tensor:
@ -311,12 +300,12 @@ class EmformerAttention(nn.Module):
KV: length of attention key and value.
1) Concat right_context, utterance, summary,
and compute query tensor with length Q = R + U + S.
and compute query with length Q = R + U + S.
2) Concat memory, right_context, utterance,
and compute key, value tensors with length KV = M + R + U;
optionally with left_context_key and left_context_val (inference mode),
and compute key, value with length KV = M + R + U;
in inference mode, also with left_context_key and left_context_val,
then KV = M + R + L + U.
3) Compute entire attention scores with query, key, and value,
3) Compute entire attention scores with the above query, key, and value,
then apply attention_mask to get underlying chunk-wise attention scores.
Args:
@ -335,14 +324,14 @@ class EmformerAttention(nn.Module):
Attention mask for underlying attention, with shape (Q, KV).
pos_emb (torch.Tensor):
Position encoding embedding, with shape (PE, D).
For training mode, PE = 2*U-1;
For infer mode, PE = L+2*U-1.
For training mode, PE = 2 * U - 1;
For inference mode, PE = L + 2 * U - 1.
left_context_key (torch.Tensor, optional):
Cached attention key of left context from preceding computation,
with shape (L, B, D).
with shape (L, B, D). It is used in inference mode.
left_context_val (torch.Tensor, optional):
Cached attention value of left context from preceding computation,
with shape (L, B, D).
with shape (L, B, D). It is used in inference mode.
Returns:
A tuple containing 4 tensors:
@ -355,23 +344,21 @@ class EmformerAttention(nn.Module):
R = right_context.size(0)
M = memory.size(0)
# Compute query with [right context, utterance, summary].
# compute query with [right context, utterance, summary].
query = self.emb_to_query(
torch.cat([right_context, utterance, summary])
)
# Compute key and value with [mems, right context, utterance].
# compute key and value with [mems, right context, utterance].
key, value = self.emb_to_key_value(
torch.cat([memory, right_context, utterance])
).chunk(chunks=2, dim=2)
if left_context_key is not None and left_context_val is not None:
# This is for inference mode. Now compute key and value with
# compute key and value with
# [mems, right context, left context, utterance]
key = torch.cat(
[key[: M + R], left_context_key, key[M + R :]] # noqa
)
key = torch.cat([key[: M + R], left_context_key, key[M + R :]])
value = torch.cat(
[value[: M + R], left_context_val, value[M + R :]] # noqa
[value[: M + R], left_context_val, value[M + R :]]
)
Q = query.size(0)
KV = key.size(0)
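To make the length bookkeeping concrete, a standalone arithmetic sketch with illustrative sizes (not the recipe's defaults):

# illustrative sizes
M, R, U, S, L = 3, 2, 12, 1, 8

# training mode: no cached left context
Q_train = R + U + S        # query: [right_context, utterance, summary]
KV_train = M + R + U       # key/value: [memory, right_context, utterance]
PE_train = 2 * U - 1       # relative distances i - j span [-(U - 1), U - 1]

# streaming inference: cached left context of length L spliced into key/value
Q_infer = R + U + S
KV_infer = M + R + L + U
PE_infer = L + 2 * U - 1   # relative distances i - j span [-(U - 1), L + U - 1]

print(Q_train, KV_train, PE_train)   # 15 17 23
print(Q_infer, KV_infer, PE_infer)   # 15 25 31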
@ -381,12 +368,14 @@ class EmformerAttention(nn.Module):
.view(KV, B * self.nhead, self.head_dim)
.transpose(0, 1)
for tensor in [key, value]
] # (B * nhead, KV, head_dim)
] # both of shape (B * nhead, KV, head_dim)
reshaped_query = query.contiguous().view(
Q, B, self.nhead, self.head_dim
)
# compute attention matrix ac
# compute attention score
# first compute attention matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # noqa
query_with_bais_u = (
(reshaped_query + self.pos_bias_u)
.view(Q, B * self.nhead, self.head_dim)
@ -396,7 +385,9 @@ class EmformerAttention(nn.Module):
query_with_bais_u, reshaped_key.transpose(1, 2)
) # (B * nhead, Q, KV)
# compute attention matrix bd
# second, compute attention matrix b and matrix d
# relative positional encoding is applied to the part of attention
# between query: [utterance] -> key, value: [left_context, utterance]
utterance_with_bais_v = (
reshaped_query[R : R + U] + self.pos_bias_v
).permute(1, 2, 0, 3)
@ -416,10 +407,10 @@ class EmformerAttention(nn.Module):
matrix_bd_utterance = torch.matmul(
utterance_with_bais_v, pos_emb.transpose(-2, -1)
) # (B, nhead, U, PE)
# rel-shift
matrix_bd_utterance = self._rel_shift(
matrix_bd_utterance
) # (B, nhead, U, U or L + U)
# rel-shift operation
matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
# (B, nhead, U, U) for training mode;
# (B, nhead, U, L + U) for inference mode.
matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
B * self.nhead, U, -1
)
@ -428,25 +419,25 @@ class EmformerAttention(nn.Module):
attention_weights = (matrix_ac + matrix_bd) * self.scaling
# Compute padding mask
# compute padding mask
if B == 1:
padding_mask = None
else:
padding_mask = make_pad_mask(KV - U + lengths)
# Compute attention probabilities.
# compute attention probabilities
attention_probs = self._gen_attention_probs(
attention_weights, attention_mask, padding_mask
)
# Compute attention.
# compute attention outputs
attention = torch.bmm(attention_probs, reshaped_value)
assert attention.shape == (B * self.nhead, Q, self.head_dim)
attention = (
attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim)
)
# Apply output projection.
# apply output projection
outputs = self.out_proj(attention)
output_right_context_utterance = outputs[: R + U]
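For reference, the padding mask above relies on icefall's make_pad_mask, which marks padded positions with True; a minimal standalone sketch of that behavior (not the library code itself):

import torch


def make_pad_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    """Return a bool mask of shape (B, max_len); True marks padded positions."""
    max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)


print(make_pad_mask_sketch(torch.tensor([5, 3, 4])))
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True],
#         [False, False, False, False,  True]])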
@ -487,7 +478,7 @@ class EmformerAttention(nn.Module):
right_context (torch.Tensor):
Right context frames, with shape (R, B, D).
summary (torch.Tensor):
Summary elements, with shape (S, B, D).
Summary elements with shape (S, B, D) or an empty tensor.
memory (torch.Tensor):
Memory elements, with shape (M, B, D).
attention_mask (torch.Tensor):
@ -495,7 +486,7 @@ class EmformerAttention(nn.Module):
with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
pos_emb (torch.Tensor):
Position encoding embedding, with shape (PE, D).
For training mode, P = 2*U-1.
where PE = 2 * U - 1.
Returns:
A tuple containing 2 tensors:
@ -549,7 +540,7 @@ class EmformerAttention(nn.Module):
right_context (torch.Tensor):
Right context frames, with shape (R, B, D).
summary (torch.Tensor):
Summary element, with shape (1, B, D), or empty.
Summary element with shape (1, B, D), or an empty tensor.
memory (torch.Tensor):
Memory elements, with shape (M, B, D).
left_context_key (torch.Tensor):
@ -571,19 +562,20 @@ class EmformerAttention(nn.Module):
- attention value of left context and utterance, which will be
cached for the next computation, with shape (L + U, B, D).
"""
U = utterance.size(0)
R = right_context.size(0)
L = left_context_key.size(0)
S = summary.size(0)
M = memory.size(0)
# query: [right context, utterance, summary]
Q = right_context.size(0) + utterance.size(0) + summary.size(0)
Q = R + U + S
# key, value: [memory, right context, left context, utterance]
KV = (
memory.size(0)
+ right_context.size(0) # noqa
+ left_context_key.size(0) # noqa
+ utterance.size(0) # noqa
)
KV = M + R + L + U
attention_mask = torch.zeros(Q, KV).to(
dtype=torch.bool, device=utterance.device
)
# Disallow attention between the summary vector and the memory bank
# disallow attention between the summary vector and the memory bank
attention_mask[-1, : memory.size(0)] = True
(
output_right_context_utterance,
@ -601,12 +593,11 @@ class EmformerAttention(nn.Module):
left_context_key=left_context_key,
left_context_val=left_context_val,
)
right_context_end_idx = memory.size(0) + right_context.size(0)
return (
output_right_context_utterance,
output_memory,
key[right_context_end_idx:],
value[right_context_end_idx:],
key[M + R :],
value[M + R :],
)
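The inference-mode mask built above is small: the only restriction is that the summary query must not attend to the memory bank. A standalone illustration with made-up sizes:

import torch

# made-up sizes; S = 1 summary vector when memory is enabled
M, R, L, U, S = 3, 2, 8, 4, 1
Q, KV = R + U + S, M + R + L + U

attention_mask = torch.zeros(Q, KV, dtype=torch.bool)
attention_mask[-1, :M] = True  # block summary (last query row) from memory keys

print(attention_mask.shape)    # torch.Size([7, 17])
print(attention_mask[-1, :5])  # tensor([ True,  True,  True, False, False])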
@ -656,6 +647,7 @@ class EmformerLayer(nn.Module):
self.attention = EmformerAttention(
embed_dim=d_model,
nhead=nhead,
dropout=dropout,
tanh_on_mem=tanh_on_mem,
negative_inf=negative_inf,
)
@ -756,9 +748,9 @@ class EmformerLayer(nn.Module):
layer_norm_input = self.layer_norm_input(
torch.cat([right_context, utterance])
)
right_context_end_idx = right_context.size(0)
layer_norm_utterance = layer_norm_input[right_context_end_idx:]
layer_norm_right_context = layer_norm_input[:right_context_end_idx]
R = right_context.size(0)
layer_norm_utterance = layer_norm_input[R:]
layer_norm_right_context = layer_norm_input[:R]
return layer_norm_utterance, layer_norm_right_context
def _apply_post_attention_ffn_layer_norm(
@ -768,18 +760,18 @@ class EmformerLayer(nn.Module):
right_context: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply feed forward and layer normalization after attention."""
# Apply residual connection between input and attention output.
# apply residual connection between input and attention output.
result = self.dropout(output_right_context_utterance) + torch.cat(
[right_context, utterance]
)
# Apply feedforward module and residual connection.
# apply feedforward module and residual connection.
result = self.pos_ff(result) + result
# Apply layer normalization for output.
# apply layer normalization for output.
result = self.layer_norm_output(result)
right_context_end_idx = right_context.size(0)
output_utterance = result[right_context_end_idx:]
output_right_context = result[:right_context_end_idx]
R = right_context.size(0)
output_utterance = result[R:]
output_right_context = result[:R]
return output_utterance, output_right_context
def _apply_attention_forward(
@ -796,7 +788,6 @@ class EmformerLayer(nn.Module):
raise ValueError(
"attention_mask must be not None in non-infer mode. "
)
if self.use_memory:
summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
2, 0, 1
@ -851,8 +842,10 @@ class EmformerLayer(nn.Module):
summary = torch.empty(0).to(
dtype=utterance.dtype, device=utterance.device
)
# pos_emb is of shape [PE, D], PE = L + 2 * U - 1,
# the relative distance j - i of key(j) and query(i) is in range of [-(L + U - 1), (U - 1)] # noqa
# pos_emb is of shape [PE, D], where PE = L + 2 * U - 1.
# For a query frame i in [utterance] and a key/value frame j in
# [left_context, utterance], the max relative distance i - j is L + U - 1
# and the min relative distance i - j is -(U - 1).
L = left_context_key.size(0) # L <= left_context_length
U = utterance.size(0)
PE = L + 2 * U - 1
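The distance bounds in the comment above can be checked mechanically by putting the L cached left-context frames and the U utterance frames on one absolute time axis (illustrative sizes only):

# illustrative sizes
L, U = 8, 4
# absolute positions: left context occupies [0, L), utterance occupies [L, L + U)
query_pos = range(L, L + U)   # queries come from the utterance
key_pos = range(0, L + U)     # keys/values cover [left_context, utterance]
distances = sorted({i - j for i in query_pos for j in key_pos})
assert distances[0] == -(U - 1) and distances[-1] == L + U - 1
assert len(distances) == L + 2 * U - 1   # matches PE above
print(distances[0], distances[-1], len(distances))  # -3 11 15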
@ -916,8 +909,8 @@ class EmformerLayer(nn.Module):
Attention mask for underlying attention module,
with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
pos_emb (torch.Tensor):
Position encoding embedding, with shape (PE, D).
For training mode, P = 2*U-1.
Position encoding embedding, with shape (PE, D),
where PE = 2 * U - 1.
Returns:
A tuple containing 3 tensors:
@ -987,8 +980,8 @@ class EmformerLayer(nn.Module):
List of tensors representing layer internal state generated in
preceding computation. (default=None)
pos_emb (torch.Tensor):
Position encoding embedding, with shape (PE, D).
For infer mode, PE = L+2*U-1.
Position encoding embedding, with shape (PE, D),
where PE = L + 2 * U - 1.
Returns:
(Tensor, Tensor, List[torch.Tensor], Tensor):
@ -1073,7 +1066,6 @@ class EmformerEncoder(nn.Module):
left_context_length: int = 0,
right_context_length: int = 0,
max_memory_size: int = 0,
weight_init_scale_strategy: str = "depthwise",
tanh_on_mem: bool = False,
negative_inf: float = -1e8,
):
@ -1104,6 +1096,8 @@ class EmformerEncoder(nn.Module):
]
)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
self.left_context_length = left_context_length
self.right_context_length = right_context_length
self.chunk_length = chunk_length
@ -1246,10 +1240,7 @@ class EmformerEncoder(nn.Module):
return attention_mask
def forward(
self,
x: torch.Tensor,
lengths: torch.Tensor,
pos_emb: torch.Tensor,
self, x: torch.Tensor, lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for training and non-streaming inference.
@ -1265,9 +1256,6 @@ class EmformerEncoder(nn.Module):
With shape (B,) and i-th element representing number of valid
utterance frames for i-th batch element in x, which contains the
right_context at the end.
pos_emb (torch.Tensor):
Position encoding embedding, with shape (PE, D).
For training mode, P = 2*U-1.
Returns:
A tuple of 2 tensors:
@ -1275,8 +1263,11 @@ class EmformerEncoder(nn.Module):
- output_lengths, with shape (B,), without containing the
right_context at the end.
"""
U = x.size(0) - self.right_context_length
x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
right_context = self._gen_right_context(x)
utterance = x[: x.size(0) - self.right_context_length]
utterance = x[:U]
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
attention_mask = self._gen_attention_mask(utterance)
memory = (
@ -1286,6 +1277,7 @@ class EmformerEncoder(nn.Module):
if self.use_memory
else torch.empty(0).to(dtype=x.dtype, device=x.device)
)
output = utterance
for layer in self.emformer_layers:
output, right_context, memory = layer(
@ -1304,7 +1296,6 @@ class EmformerEncoder(nn.Module):
self,
x: torch.Tensor,
lengths: torch.Tensor,
pos_emb: torch.Tensor,
states: Optional[List[List[torch.Tensor]]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
"""Forward pass for streaming inference.
@ -1325,9 +1316,6 @@ class EmformerEncoder(nn.Module):
Cached states from the preceding chunk's computation, where each
element (List[torch.Tensor]) corresponds to one emformer layer.
(default: None)
pos_emb (torch.Tensor):
Position encoding embedding, with shape (PE, D).
For infer mode, PE = L+2*U-1.
Returns:
(Tensor, Tensor, List[List[torch.Tensor]]):
@ -1341,9 +1329,12 @@ class EmformerEncoder(nn.Module):
f"expected size of {self.chunk_length + self.right_context_length} "
f"for dimension 1 of x, but got {x.size(1)}."
)
right_context_start_idx = x.size(0) - self.right_context_length
right_context = x[right_context_start_idx:]
utterance = x[:right_context_start_idx]
pos_len = self.chunk_length + self.left_context_length
neg_len = self.chunk_length
x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
right_context = x[self.chunk_length :]
utterance = x[: self.chunk_length]
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
memory = (
self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
@ -1383,7 +1374,6 @@ class Emformer(EncoderInterface):
left_context_length: int = 0,
right_context_length: int = 0,
max_memory_size: int = 0,
weight_init_scale_strategy: str = "depthwise",
tanh_on_mem: bool = False,
negative_inf: float = -1e8,
):
@ -1416,8 +1406,6 @@ class Emformer(EncoderInterface):
else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
self.encoder = EmformerEncoder(
chunk_length // 4,
d_model,
@ -1429,7 +1417,6 @@ class Emformer(EncoderInterface):
left_context_length=left_context_length // 4,
right_context_length=right_context_length // 4,
max_memory_size=max_memory_size,
weight_init_scale_strategy=weight_init_scale_strategy,
tanh_on_mem=tanh_on_mem,
negative_inf=negative_inf,
)
@ -1465,10 +1452,6 @@ class Emformer(EncoderInterface):
right_context at the end.
"""
x = self.encoder_embed(x)
# TODO: The length computation in the encoder class should be moved here. # noqa
U = x.size(1) - self.right_context_length // 4
x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4!
@ -1477,7 +1460,7 @@ class Emformer(EncoderInterface):
x_lens = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == x_lens.max().item()
output, output_lengths = self.encoder(x, x_lens, pos_emb) # (T, N, C)
output, output_lengths = self.encoder(x, x_lens) # (T, N, C)
logits = self.encoder_output_layer(output)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
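The length formula above mirrors the two stride-2 convolutions of the subsampling front end (hence the factor-4 caution); a quick numeric check:

# ((x_lens - 1) // 2 - 1) // 2 for a few input lengths
for n in (20, 23, 24):
    print(n, ((n - 1) // 2 - 1) // 2)
# 20 4
# 23 5
# 24 5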
@ -1518,12 +1501,6 @@ class Emformer(EncoderInterface):
- updated states from current chunk's computation.
"""
x = self.encoder_embed(x)
# TODO: The length computation in the encoder class should be moved here. # noqa
pos_len = self.chunk_length // 4 + self.left_context_length // 4
neg_len = self.chunk_length // 4
x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4!
@ -1533,7 +1510,7 @@ class Emformer(EncoderInterface):
assert x.size(0) == x_lens.max().item()
output, output_lengths, output_states = self.encoder.infer(
x, x_lens, pos_emb, states
x, x_lens, states
) # (T, N, C)
logits = self.encoder_output_layer(output)
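Putting the encoder-level changes together: callers of EmformerEncoder no longer build positional embeddings themselves. A usage sketch, assuming the constructor arguments shown in the tests below and illustrative sizes:

import torch
from emformer import EmformerEncoder  # this recipe's emformer.py

chunk_length, left_context_length, right_context_length = 4, 2, 2
U, B, D = 12, 2, 256  # U = 3 chunks of 4 frames
encoder = EmformerEncoder(
    chunk_length,
    d_model=D,
    dim_feedforward=1024,
    num_encoder_layers=2,
    left_context_length=left_context_length,
    right_context_length=right_context_length,
)
encoder.eval()

x = torch.randn(U + right_context_length, B, D)
lengths = torch.full((B,), U + right_context_length)
output, output_lengths = encoder(x, lengths)  # no pos_emb argument any more
print(output.shape)                           # torch.Size([12, 2, 256])

# streaming inference likewise drops pos_emb:
# output, output_lengths, states = encoder.infer(chunk, chunk_lengths, states)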

View File

@ -199,7 +199,6 @@ def test_emformer_encoder_forward():
chunk_length = 4
right_context_length = 2
left_context_length = 2
left_context_length = 2
num_chunks = 3
U = num_chunks * chunk_length
@ -223,10 +222,8 @@ def test_emformer_encoder_forward():
x = torch.randn(U + right_context_length, B, D)
lengths = torch.randint(1, U + right_context_length + 1, (B,))
lengths[0] = U + right_context_length
PE = 2 * U - 1
pos_emb = torch.randn(PE, D)
output, output_lengths = encoder(x, lengths, pos_emb)
output, output_lengths = encoder(x, lengths)
assert output.shape == (U, B, D)
assert torch.equal(
output_lengths, torch.clamp(lengths - right_context_length, min=0)
@ -266,11 +263,7 @@ def test_emformer_encoder_infer():
1, chunk_length + right_context_length + 1, (B,)
)
lengths[0] = chunk_length + right_context_length
PE = left_context_length + 2 * chunk_length - 1
pos_emb = torch.randn(PE, D)
output, output_lengths, states = encoder.infer(
x, lengths, pos_emb, states
)
output, output_lengths, states = encoder.infer(x, lengths, states)
assert output.shape == (chunk_length, B, D)
assert torch.equal(
output_lengths,
@ -383,6 +376,7 @@ def test_emformer_infer():
def test_emformer_attention_forward_infer_consistency():
# TODO: delete
from emformer import EmformerEncoder
chunk_length = 4
@ -474,7 +468,7 @@ def test_emformer_layer_forward_infer_consistency():
chunk_length = 4
num_chunks = 3
U = chunk_length * num_chunks
L, R = 1, 2
left_context_length, right_context_length = 1, 2
D = 256
num_encoder_layers = 1
memory_sizes = [0, 3]
@ -485,18 +479,22 @@ def test_emformer_layer_forward_infer_consistency():
d_model=D,
dim_feedforward=1024,
num_encoder_layers=num_encoder_layers,
left_context_length=L,
right_context_length=R,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=M,
dropout=0.1,
)
encoder.eval()
encoder_layer = encoder.emformer_layers[0]
encoder_pos = encoder.encoder_pos
x = torch.randn(U + R, 1, D)
x = torch.randn(U + right_context_length, 1, D)
# training mode with full utterance
x_forward, pos_emb = encoder_pos(x, U, U)
lengths = torch.tensor([U])
right_context = encoder._gen_right_context(x)
utterance = x[: x.size(0) - R]
right_context = encoder._gen_right_context(x_forward)
utterance = x_forward[:U]
attention_mask = encoder._gen_attention_mask(utterance)
memory = (
encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
@ -515,15 +513,20 @@ def test_emformer_layer_forward_infer_consistency():
right_context,
memory,
attention_mask,
pos_emb,
)
state = None
for chunk_idx in range(num_chunks):
start_idx = chunk_idx * chunk_length
end_idx = start_idx + chunk_length
chunk = x[start_idx:end_idx]
chunk_right_context = x[end_idx : end_idx + R] # noqa
chunk_length = torch.tensor([chunk_length])
cur_x, pos_emb = encoder_pos(
x[start_idx : end_idx + right_context_length],
pos_len=chunk_length + left_context_length,
neg_len=chunk_length,
)
chunk = cur_x[:chunk_length]
chunk_right_context = cur_x[chunk_length:]
chunk_memory = (
encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(2, 0, 1)
if encoder.use_memory
@ -536,9 +539,10 @@ def test_emformer_layer_forward_infer_consistency():
state,
) = encoder_layer.infer(
chunk,
chunk_length,
torch.tensor([chunk_length]),
chunk_right_context,
chunk_memory,
pos_emb,
state,
)
forward_output_chunk = forward_output_utterance[start_idx:end_idx]
@ -551,7 +555,7 @@ def test_emformer_layer_forward_infer_consistency():
def test_emformer_encoder_forward_infer_consistency():
from emformer import EmformerEncoder, RelPositionalEncoding
from emformer import EmformerEncoder
chunk_length = 4
num_chunks = 3
@ -573,28 +577,22 @@ def test_emformer_encoder_forward_infer_consistency():
dropout=0.1,
)
encoder.eval()
encoder_pos = RelPositionalEncoding(D, dropout_rate=0)
x = torch.randn(U + right_context_length, 1, D)
lengths = torch.tensor([U + right_context_length])
_, pos_emb = encoder_pos(x, U, U)
forward_output, forward_output_lengths = encoder(x, lengths, pos_emb)
# training mode with full utterance
forward_output, forward_output_lengths = encoder(x, lengths)
# streaming inference mode with individual chunks
states = None
_, pos_emb = encoder_pos(
x, chunk_length + left_context_length, chunk_length
)
for chunk_idx in range(num_chunks):
start_idx = chunk_idx * chunk_length
end_idx = start_idx + chunk_length
chunk = x[start_idx : end_idx + right_context_length] # noqa
chunk_length = torch.tensor([chunk_length])
infer_output_chunk, infer_output_lengths, states = encoder.infer(
chunk,
chunk_length,
pos_emb,
states,
chunk, chunk_length, states
)
forward_output_chunk = forward_output[start_idx:end_idx]
assert torch.allclose(
@ -615,7 +613,7 @@ def test_emformer_infer_batch_single_consistency():
chunk_length = 8
num_chunks = 3
U = num_chunks * chunk_length
L, R = 128, 4
left_context_length, right_context_length = 128, 4
B, D = 2, 256
num_encoder_layers = 2
for use_memory in [True, False]:
@ -630,8 +628,8 @@ def test_emformer_infer_batch_single_consistency():
subsampling_factor=4,
d_model=D,
num_encoder_layers=num_encoder_layers,
left_context_length=L,
right_context_length=R,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=M,
vgg_frontend=False,
)
@ -689,20 +687,25 @@ def test_emformer_infer_batch_single_consistency():
],
)
x = torch.randn(B, U + R + 3, num_features)
x = torch.randn(B, U + right_context_length + 3, num_features)
# batch-wise inference
batch_logits = []
batch_states = []
states = None
for chunk_idx in range(num_chunks):
start_idx = chunk_idx * chunk_length
end_idx = start_idx + chunk_length
chunk = x[:, start_idx : end_idx + R + 3] # noqa
lengths = torch.tensor([chunk_length + R + 3]).expand(B)
chunk = x[:, start_idx : end_idx + right_context_length + 3] # noqa
lengths = torch.tensor(
[chunk_length + right_context_length + 3]
).expand(B)
logits, output_lengths, states = model.infer(chunk, lengths, states)
batch_logits.append(logits)
batch_states.append(save_states(states))
batch_logits = torch.cat(batch_logits, dim=1)
# single-sample inference
single_logits = []
for sample_idx in range(B):
sample = x[sample_idx : sample_idx + 1] # noqa
@ -711,17 +714,21 @@ def test_emformer_infer_batch_single_consistency():
for chunk_idx in range(num_chunks):
start_idx = chunk_idx * chunk_length
end_idx = start_idx + chunk_length
chunk = sample[:, start_idx : end_idx + R + 3] # noqa
lengths = torch.tensor([chunk_length + R + 3])
chunk = sample[
:, start_idx : end_idx + right_context_length + 3
]
lengths = torch.tensor(
[chunk_length + right_context_length + 3]
)
logits, output_lengths, states = model.infer(
chunk, lengths, states
)
chunk_logits.append(logits)
assert_states_equal(batch_states[chunk_idx], states, sample_idx)
chunk_logits = torch.cat(chunk_logits, dim=1)
single_logits.append(chunk_logits)
single_logits = torch.cat(single_logits, dim=0)
assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0)
@ -734,7 +741,7 @@ def test_emformer_infer_states_stack():
output_dim = 1000
chunk_length = 8
U = chunk_length
L, R = 128, 4
left_context_length, right_context_length = 128, 4
B, D = 2, 256
num_encoder_layers = 2
for use_memory in [True, False]:
@ -749,14 +756,14 @@ def test_emformer_infer_states_stack():
subsampling_factor=4,
d_model=D,
num_encoder_layers=num_encoder_layers,
left_context_length=L,
right_context_length=R,
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=M,
vgg_frontend=False,
)
x = torch.randn(B, U + R + 3, num_features)
x_lens = torch.full((B,), U + R + 3)
x = torch.randn(B, U + right_context_length + 3, num_features)
x_lens = torch.full((B,), U + right_context_length + 3)
logits, output_lengths, states = model.infer(
x,
x_lens,
@ -790,8 +797,8 @@ if __name__ == "__main__":
test_emformer_forward()
test_emformer_infer()
# test_emformer_attention_forward_infer_consistency()
# test_emformer_layer_forward_infer_consistency()
test_emformer_layer_forward_infer_consistency()
test_emformer_encoder_forward_infer_consistency()
# test_emformer_infer_batch_single_consistency()
# test_emformer_infer_states_stack()
test_emformer_infer_batch_single_consistency()
test_emformer_infer_states_stack()
test_rel_positional_encoding()