diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py
index e8482944c..6ecc8d420 100644
--- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py
@@ -40,24 +40,6 @@ def _get_activation_module(activation: str) -> nn.Module:
     raise ValueError(f"Unsupported activation {activation}")


-def _get_weight_init_gains(
-    weight_init_scale_strategy: Optional[str], num_layers: int
-) -> List[Optional[float]]:
-    if weight_init_scale_strategy is None:
-        return [None for _ in range(num_layers)]
-    elif weight_init_scale_strategy == "depthwise":
-        return [
-            1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)
-        ]
-    elif weight_init_scale_strategy == "constant":
-        return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
-    else:
-        raise ValueError(
-            f"Unsupported weight_init_scale_strategy value"
-            f"{weight_init_scale_strategy}"
-        )
-
-
 def _gen_attention_mask_block(
     col_widths: List[int],
     col_mask: List[bool],
@@ -154,6 +136,8 @@ class EmformerAttention(nn.Module):
         Embedding dimension.
       nhead (int):
         Number of attention heads in each Emformer layer.
+      dropout (float):
+        Dropout probability on attn_output_weights. (Default: 0.0)
       tanh_on_mem (bool, optional):
         If ``True``, applies tanh to memory elements. (Default: ``False``)
       negative_inf (float, optional):
@@ -164,6 +148,7 @@ class EmformerAttention(nn.Module):
         self,
         embed_dim: int,
         nhead: int,
+        dropout: float = 0.0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
     ):
@@ -173,13 +158,14 @@ class EmformerAttention(nn.Module):
             raise ValueError(
                 f"embed_dim ({embed_dim}) is not a multiple of nhead ({nhead})."
             )
-
         self.embed_dim = embed_dim
         self.nhead = nhead
         self.tanh_on_mem = tanh_on_mem
         self.negative_inf = negative_inf
         self.head_dim = embed_dim // nhead
+        self.dropout = dropout
+        self.scaling = self.head_dim ** -0.5

         self.emb_to_key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
@@ -262,6 +248,9 @@ class EmformerAttention(nn.Module):
             attention_weights_float, dim=-1
         ).type_as(attention_weights)
+        attention_probs = nn.functional.dropout(
+            attention_probs, p=self.dropout, training=self.training
+        )
         return attention_probs

     def _rel_shift(self, x: torch.Tensor) -> torch.Tensor:
@@ -311,12 +300,12 @@ class EmformerAttention(nn.Module):
        KV: length of attention key and value.

        1) Concat right_context, utterance, summary,
-       and compute query tensor with length Q = R + U + S.
+       and compute query with length Q = R + U + S.
        2) Concat memory, right_context, utterance,
-       and compute key, value tensors with length KV = M + R + U;
-       optionally with left_context_key and left_context_val (inference mode),
+       and compute key, value with length KV = M + R + U;
+       also with left_context_key and left_context_val for inference mode,
        then KV = M + R + L + U.
-       3) Compute entire attention scores with query, key, and value,
+       3) Compute entire attention scores with the above query, key, and value,
        then apply attention_mask to get underlying chunk-wise attention scores.

        Args:
@@ -335,14 +324,14 @@ class EmformerAttention(nn.Module):
            Attention mask for underlying attention, with shape (Q, KV).
          pos_emb (torch.Tensor):
            Position encoding embedding, with shape (PE, D).
-            For training mode, PE = 2*U-1;
-            For infer mode, PE = L+2*U-1.
+            For training mode, PE = 2 * U - 1;
+            For inference mode, PE = L + 2 * U - 1.
          left_context_key (torch,Tensor, optional):
            Cached attention key of left context from preceding computation,
-            with shape (L, B, D).
+            with shape (L, B, D). It is used for inference mode.
          left_context_val (torch.Tensor, optional):
            Cached attention value of left context from preceding computation,
-            with shape (L, B, D).
+            with shape (L, B, D). It is used for inference mode.

        Returns:
          A tuple containing 4 tensors:
@@ -355,23 +344,21 @@ class EmformerAttention(nn.Module):
         R = right_context.size(0)
         M = memory.size(0)

-        # Compute query with [right context, utterance, summary].
+        # compute query with [right context, utterance, summary].
         query = self.emb_to_query(
             torch.cat([right_context, utterance, summary])
         )
-        # Compute key and value with [mems, right context, utterance].
+        # compute key and value with [mems, right context, utterance].
         key, value = self.emb_to_key_value(
             torch.cat([memory, right_context, utterance])
         ).chunk(chunks=2, dim=2)

         if left_context_key is not None and left_context_val is not None:
-            # This is for inference mode. Now compute key and value with
+            # compute key and value with
             # [mems, right context, left context, uttrance]
-            key = torch.cat(
-                [key[: M + R], left_context_key, key[M + R :]]  # noqa
-            )
+            key = torch.cat([key[: M + R], left_context_key, key[M + R :]])
             value = torch.cat(
-                [value[: M + R], left_context_val, value[M + R :]]  # noqa
+                [value[: M + R], left_context_val, value[M + R :]]
             )
         Q = query.size(0)
         KV = key.size(0)
@@ -381,12 +368,14 @@ class EmformerAttention(nn.Module):
             .view(KV, B * self.nhead, self.head_dim)
             .transpose(0, 1)
             for tensor in [key, value]
-        ]  # (B * nhead, KV, head_dim)
+        ]  # both of shape (B * nhead, KV, head_dim)
         reshaped_query = query.contiguous().view(
             Q, B, self.nhead, self.head_dim
         )

-        # compute attention matrix ac
+        # compute attention score
+        # first compute attention matrix a and matrix c
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3  # noqa
         query_with_bais_u = (
             (reshaped_query + self.pos_bias_u)
             .view(Q, B * self.nhead, self.head_dim)
             query_with_bais_u, reshaped_key.transpose(1, 2)
         )  # (B * nhead, Q, KV)

-        # compute attention matrix bd
+        # second, compute attention matrix b and matrix d
+        # relative positional encoding is applied on the part of attention
+        # between query: [utterance] -> key, value: [left_context, utterance]
         utterance_with_bais_v = (
             reshaped_query[R : R + U] + self.pos_bias_v
         ).permute(1, 2, 0, 3)
@@ -416,10 +407,10 @@ class EmformerAttention(nn.Module):
         matrix_bd_utterance = torch.matmul(
             utterance_with_bais_v, pos_emb.transpose(-2, -1)
         )  # (B, nhead, U, PE)
-        # rel-shift
-        matrix_bd_utterance = self._rel_shift(
-            matrix_bd_utterance
-        )  # (B, nhead, U, U or L + U)
+        # rel-shift operation
+        matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
+        # (B, nhead, U, U) for training mode;
+        # (B, nhead, U, L + U) for inference mode.
         matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
             B * self.nhead, U, -1
         )
@@ -428,25 +419,25 @@ class EmformerAttention(nn.Module):

         attention_weights = (matrix_ac + matrix_bd) * self.scaling

-        # Compute padding mask
+        # compute padding mask
         if B == 1:
             padding_mask = None
         else:
             padding_mask = make_pad_mask(KV - U + lengths)

-        # Compute attention probabilities.
+        # compute attention probabilities
         attention_probs = self._gen_attention_probs(
             attention_weights, attention_mask, padding_mask
         )

-        # Compute attention.
+        # compute attention outputs
         attention = torch.bmm(attention_probs, reshaped_value)
         assert attention.shape == (B * self.nhead, Q, self.head_dim)
         attention = (
             attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim)
         )

-        # Apply output projection.
+        # apply output projection
         outputs = self.out_proj(attention)

         output_right_context_utterance = outputs[: R + U]
@@ -487,7 +478,7 @@ class EmformerAttention(nn.Module):
          right_context (torch.Tensor):
            Right context frames, with shape (R, B, D).
          summary (torch.Tensor):
-            Summary elements, with shape (S, B, D).
+            Summary elements with shape (S, B, D) or an empty tensor.
          memory (torch.Tensor):
            Memory elements, with shape (M, B, D).
          attention_mask (torch.Tensor):
@@ -495,7 +486,7 @@ class EmformerAttention(nn.Module):
            with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
          pos_emb (torch.Tensor):
            Position encoding embedding, with shape (PE, D).
-            For training mode, P = 2*U-1.
+            where PE = 2 * U - 1.

        Returns:
          A tuple containing 2 tensors:
@@ -549,7 +540,7 @@ class EmformerAttention(nn.Module):
          right_context (torch.Tensor):
            Right context frames, with shape (R, B, D).
          summary (torch.Tensor):
-            Summary element, with shape (1, B, D), or empty.
+            Summary element with shape (1, B, D), or an empty tensor.
          memory (torch.Tensor):
            Memory elements, with shape (M, B, D).
          left_context_key (torch,Tensor):
@@ -571,19 +562,20 @@ class EmformerAttention(nn.Module):
          - attention value of left context and utterance, which would
            be cached for next computation, with shape (L + U, B, D).
        """
+        U = utterance.size(0)
+        R = right_context.size(0)
+        L = left_context_key.size(0)
+        S = summary.size(0)
+        M = memory.size(0)
+
        # query: [right context, utterance, summary]
-        Q = right_context.size(0) + utterance.size(0) + summary.size(0)
+        Q = R + U + S
        # key, value: [memory, right context, left context, uttrance]
-        KV = (
-            memory.size(0)
-            + right_context.size(0)  # noqa
-            + left_context_key.size(0)  # noqa
-            + utterance.size(0)  # noqa
-        )
+        KV = M + R + L + U
         attention_mask = torch.zeros(Q, KV).to(
             dtype=torch.bool, device=utterance.device
         )
-        # Disallow attention bettween the summary vector with the memory bank
+        # disallow attention between the summary vector and the memory bank
         attention_mask[-1, : memory.size(0)] = True
         (
             output_right_context_utterance,
             output_memory,
@@ -601,12 +593,11 @@ class EmformerAttention(nn.Module):
             left_context_key=left_context_key,
             left_context_val=left_context_val,
         )
-        right_context_end_idx = memory.size(0) + right_context.size(0)
         return (
             output_right_context_utterance,
             output_memory,
-            key[right_context_end_idx:],
-            value[right_context_end_idx:],
+            key[M + R :],
+            value[M + R :],
         )
@@ -656,6 +647,7 @@ class EmformerLayer(nn.Module):
         self.attention = EmformerAttention(
             embed_dim=d_model,
             nhead=nhead,
+            dropout=dropout,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
         )
@@ -756,9 +748,9 @@ class EmformerLayer(nn.Module):
         layer_norm_input = self.layer_norm_input(
             torch.cat([right_context, utterance])
         )
-        right_context_end_idx = right_context.size(0)
-        layer_norm_utterance = layer_norm_input[right_context_end_idx:]
-        layer_norm_right_context = layer_norm_input[:right_context_end_idx]
+        R = right_context.size(0)
+        layer_norm_utterance = layer_norm_input[R:]
+        layer_norm_right_context = layer_norm_input[:R]
         return layer_norm_utterance, layer_norm_right_context

     def _apply_post_attention_ffn_layer_norm(
@@ -768,18 +760,18 @@ class EmformerLayer(nn.Module):
         right_context: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Apply feed forward and layer normalization after attention."""
-        # Apply residual connection between input and attention output.
+        # apply residual connection between input and attention output.
         result = self.dropout(output_right_context_utterance) + torch.cat(
             [right_context, utterance]
         )
-        # Apply feedforward module and residual connection.
+        # apply feedforward module and residual connection.
         result = self.pos_ff(result) + result
-        # Apply layer normalization for output.
+        # apply layer normalization for output.
         result = self.layer_norm_output(result)

-        right_context_end_idx = right_context.size(0)
-        output_utterance = result[right_context_end_idx:]
-        output_right_context = result[:right_context_end_idx]
+        R = right_context.size(0)
+        output_utterance = result[R:]
+        output_right_context = result[:R]
         return output_utterance, output_right_context

     def _apply_attention_forward(
@@ -796,7 +788,6 @@ class EmformerLayer(nn.Module):
             raise ValueError(
                 "attention_mask must be not None in non-infer mode. "
             )
-
         if self.use_memory:
             summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
                 2, 0, 1
@@ -851,8 +842,10 @@ class EmformerLayer(nn.Module):
             summary = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
-        # pos_emb is of shape [PE, D], PE = L + 2 * U - 1,
-        # the relative distance j - i of key(j) and query(i) is in range of [-(L + U - 1), (U - 1)]  # noqa
+        # pos_emb is of shape [PE, D], where PE = L + 2 * U - 1,
+        # for query [utterance] (i) and key-value [left_context, utterance] (j),
+        # the max relative distance i - j is L + U - 1
+        # the min relative distance i - j is -(U - 1)
         L = left_context_key.size(0)  # L <= left_context_length
         U = utterance.size(0)
         PE = L + 2 * U - 1
@@ -916,8 +909,8 @@ class EmformerLayer(nn.Module):
            Attention mask for underlying attention module,
            with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For training mode, P = 2*U-1.
+            Position encoding embedding, with shape (PE, D),
+            where PE = 2 * U - 1.

        Returns:
          A tuple containing 3 tensors:
@@ -987,8 +980,8 @@ class EmformerLayer(nn.Module):
            List of tensors representing layer internal state generated in
            preceding computation. (default=None)
          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For infer mode, PE = L+2*U-1.
+            Position encoding embedding, with shape (PE, D),
+            where PE = L + 2 * U - 1.

        Returns:
          (Tensor, Tensor, List[torch.Tensor], Tensor):
@@ -1073,7 +1066,6 @@ class EmformerEncoder(nn.Module):
         left_context_length: int = 0,
         right_context_length: int = 0,
         max_memory_size: int = 0,
-        weight_init_scale_strategy: str = "depthwise",
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
     ):
@@ -1104,6 +1096,8 @@ class EmformerEncoder(nn.Module):
             ]
         )

+        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+
         self.left_context_length = left_context_length
         self.right_context_length = right_context_length
         self.chunk_length = chunk_length
@@ -1246,10 +1240,7 @@ class EmformerEncoder(nn.Module):
         return attention_mask

     def forward(
-        self,
-        x: torch.Tensor,
-        lengths: torch.Tensor,
-        pos_emb: torch.Tensor,
+        self, x: torch.Tensor, lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward pass for training and non-streaming inference.

@@ -1265,9 +1256,6 @@ class EmformerEncoder(nn.Module):
            With shape (B,) and i-th element representing number of valid
            utterance frames for i-th batch element in x, which contains the
            right_context at the end.
-          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For training mode, P = 2*U-1.

        Returns:
          A tuple of 2 tensors:
@@ -1275,8 +1263,11 @@ class EmformerEncoder(nn.Module):
          - output_lengths, with shape (B,), without containing the
            right_context at the end.
        """
+        U = x.size(0) - self.right_context_length
+        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
+
         right_context = self._gen_right_context(x)
-        utterance = x[: x.size(0) - self.right_context_length]
+        utterance = x[:U]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)
         memory = (
@@ -1286,6 +1277,7 @@ class EmformerEncoder(nn.Module):
             if self.use_memory
             else torch.empty(0).to(dtype=x.dtype, device=x.device)
         )
+
         output = utterance
         for layer in self.emformer_layers:
             output, right_context, memory = layer(
@@ -1304,7 +1296,6 @@ class EmformerEncoder(nn.Module):
         self,
         x: torch.Tensor,
         lengths: torch.Tensor,
-        pos_emb: torch.Tensor,
         states: Optional[List[List[torch.Tensor]]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
         """Forward pass for streaming inference.

@@ -1325,9 +1316,6 @@ class EmformerEncoder(nn.Module):
            Cached states from proceeding chunk's computation, where each
            element (List[torch.Tensor]) corresponding to each emformer layer.
            (default: None)
-          pos_emb (torch.Tensor):
-            Position encoding embedding, with shape (PE, D).
-            For infer mode, PE = L+2*U-1.

        Returns:
          (Tensor, Tensor, List[List[torch.Tensor]]):
@@ -1341,9 +1329,12 @@ class EmformerEncoder(nn.Module):
             f"expected size of {self.chunk_length + self.right_context_length} "
             f"for dimension 1 of x, but got {x.size(1)}."
         )
-        right_context_start_idx = x.size(0) - self.right_context_length
-        right_context = x[right_context_start_idx:]
-        utterance = x[:right_context_start_idx]
+        pos_len = self.chunk_length + self.left_context_length
+        neg_len = self.chunk_length
+        x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
+
+        right_context = x[self.chunk_length :]
+        utterance = x[: self.chunk_length]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         memory = (
             self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
             if self.use_memory
@@ -1383,7 +1374,6 @@ class Emformer(EncoderInterface):
         left_context_length: int = 0,
         right_context_length: int = 0,
         max_memory_size: int = 0,
-        weight_init_scale_strategy: str = "depthwise",
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
     ):
@@ -1416,8 +1406,6 @@ class Emformer(EncoderInterface):
         else:
             self.encoder_embed = Conv2dSubsampling(num_features, d_model)

-        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
-
         self.encoder = EmformerEncoder(
             chunk_length // 4,
             d_model,
@@ -1429,7 +1417,6 @@ class Emformer(EncoderInterface):
             left_context_length=left_context_length // 4,
             right_context_length=right_context_length // 4,
             max_memory_size=max_memory_size,
-            weight_init_scale_strategy=weight_init_scale_strategy,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
         )
@@ -1465,10 +1452,6 @@ class Emformer(EncoderInterface):
            right_context at the end.
        """
        x = self.encoder_embed(x)
-
-        # TODO: The length computation in the encoder class should be moved here.  # noqa
-        U = x.size(1) - self.right_context_length // 4
-        x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # Caution: We assume the subsampling factor is 4!
@@ -1477,7 +1460,7 @@ class Emformer(EncoderInterface):
         x_lens = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(0) == x_lens.max().item()

-        output, output_lengths = self.encoder(x, x_lens, pos_emb)  # (T, N, C)
+        output, output_lengths = self.encoder(x, x_lens)  # (T, N, C)

         logits = self.encoder_output_layer(output)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
@@ -1518,12 +1501,6 @@ class Emformer(EncoderInterface):
            - updated states from current chunk's computation.
        """
        x = self.encoder_embed(x)
-
-        # TODO: The length computation in the encoder class should be moved here.  # noqa
-        pos_len = self.chunk_length // 4 + self.left_context_length // 4
-        neg_len = self.chunk_length // 4
-        x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len)
-
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # Caution: We assume the subsampling factor is 4!
@@ -1533,7 +1510,7 @@ class Emformer(EncoderInterface):
         assert x.size(0) == x_lens.max().item()

         output, output_lengths, output_states = self.encoder.infer(
-            x, x_lens, pos_emb, states
+            x, x_lens, states
         )  # (T, N, C)

         logits = self.encoder_output_layer(output)
diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py
index 4cbb43f81..b2b1000cc 100644
--- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py
+++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py
@@ -199,7 +199,6 @@ def test_emformer_encoder_forward():
     chunk_length = 4
     right_context_length = 2
     left_context_length = 2
-    left_context_length = 2
     num_chunks = 3
     U = num_chunks * chunk_length
@@ -223,10 +222,8 @@ def test_emformer_encoder_forward():
     x = torch.randn(U + right_context_length, B, D)
     lengths = torch.randint(1, U + right_context_length + 1, (B,))
     lengths[0] = U + right_context_length
-    PE = 2 * U - 1
-    pos_emb = torch.randn(PE, D)

-    output, output_lengths = encoder(x, lengths, pos_emb)
+    output, output_lengths = encoder(x, lengths)
     assert output.shape == (U, B, D)
     assert torch.equal(
         output_lengths, torch.clamp(lengths - right_context_length, min=0)
     )
@@ -266,11 +263,7 @@ def test_emformer_encoder_infer():
             1, chunk_length + right_context_length + 1, (B,)
         )
         lengths[0] = chunk_length + right_context_length
-        PE = left_context_length + 2 * chunk_length - 1
-        pos_emb = torch.randn(PE, D)
-        output, output_lengths, states = encoder.infer(
-            x, lengths, pos_emb, states
-        )
+        output, output_lengths, states = encoder.infer(x, lengths, states)
         assert output.shape == (chunk_length, B, D)
         assert torch.equal(
             output_lengths,
@@ -383,6 +376,7 @@ def test_emformer_infer():


 def test_emformer_attention_forward_infer_consistency():
+    # TODO: delete
     from emformer import EmformerEncoder

     chunk_length = 4
@@ -474,7 +468,7 @@ def test_emformer_layer_forward_infer_consistency():
     chunk_length = 4
     num_chunks = 3
     U = chunk_length * num_chunks
-    L, R = 1, 2
+    left_context_length, right_context_length = 1, 2
     D = 256
     num_encoder_layers = 1
     memory_sizes = [0, 3]
@@ -485,18 +479,22 @@ def test_emformer_layer_forward_infer_consistency():
            d_model=D,
            dim_feedforward=1024,
            num_encoder_layers=num_encoder_layers,
-            left_context_length=L,
-            right_context_length=R,
+            left_context_length=left_context_length,
+            right_context_length=right_context_length,
            max_memory_size=M,
            dropout=0.1,
        )
        encoder.eval()
        encoder_layer = encoder.emformer_layers[0]
+        encoder_pos = encoder.encoder_pos

-        x = torch.randn(U + R, 1, D)
+        x = torch.randn(U + right_context_length, 1, D)
+
+        # training mode with full utterance
+        x_forward, pos_emb = encoder_pos(x, U, U)
        lengths = torch.tensor([U])
-        right_context = encoder._gen_right_context(x)
-        utterance = x[: x.size(0) - R]
+        right_context = encoder._gen_right_context(x_forward)
+        utterance = x_forward[:U]
        attention_mask = encoder._gen_attention_mask(utterance)
        memory = (
            encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
@@ -515,15 +513,20 @@ def test_emformer_layer_forward_infer_consistency():
            right_context,
            memory,
            attention_mask,
+            pos_emb,
        )

        state = None
        for chunk_idx in range(num_chunks):
            start_idx = chunk_idx * chunk_length
            end_idx = start_idx + chunk_length
-            chunk = x[start_idx:end_idx]
-            chunk_right_context = x[end_idx : end_idx + R]  # noqa
-            chunk_length = torch.tensor([chunk_length])
+            cur_x, pos_emb = encoder_pos(
+                x[start_idx : end_idx + right_context_length],
+                pos_len=chunk_length + left_context_length,
+                neg_len=chunk_length,
+            )
+            chunk = cur_x[:chunk_length]
+            chunk_right_context = cur_x[chunk_length:]
            chunk_memory = (
                encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(2, 0, 1)
                if encoder.use_memory
@@ -536,9 +539,10 @@ def test_emformer_layer_forward_infer_consistency():
                state,
            ) = encoder_layer.infer(
                chunk,
-                chunk_length,
+                torch.tensor([chunk_length]),
                chunk_right_context,
                chunk_memory,
+                pos_emb,
                state,
            )
            forward_output_chunk = forward_output_utterance[start_idx:end_idx]
@@ -551,7 +555,7 @@ def test_emformer_layer_forward_infer_consistency():

 def test_emformer_encoder_forward_infer_consistency():
-    from emformer import EmformerEncoder, RelPositionalEncoding
+    from emformer import EmformerEncoder

     chunk_length = 4
     num_chunks = 3
@@ -573,28 +577,22 @@ def test_emformer_encoder_forward_infer_consistency():
        dropout=0.1,
    )
    encoder.eval()
-    encoder_pos = RelPositionalEncoding(D, dropout_rate=0)

    x = torch.randn(U + right_context_length, 1, D)
    lengths = torch.tensor([U + right_context_length])
-    _, pos_emb = encoder_pos(x, U, U)
-    forward_output, forward_output_lengths = encoder(x, lengths, pos_emb)
+    # training mode with full utterance
+    forward_output, forward_output_lengths = encoder(x, lengths)

+    # streaming inference mode with individual chunks
    states = None
-    _, pos_emb = encoder_pos(
-        x, chunk_length + left_context_length, chunk_length
-    )
    for chunk_idx in range(num_chunks):
        start_idx = chunk_idx * chunk_length
        end_idx = start_idx + chunk_length
        chunk = x[start_idx : end_idx + right_context_length]  # noqa
        chunk_length = torch.tensor([chunk_length])
        infer_output_chunk, infer_output_lengths, states = encoder.infer(
-            chunk,
-            chunk_length,
-            pos_emb,
-            states,
+            chunk, chunk_length, states
        )
        forward_output_chunk = forward_output[start_idx:end_idx]
        assert torch.allclose(
@@ -615,7 +613,7 @@ def test_emformer_infer_batch_single_consistency():
     chunk_length = 8
     num_chunks = 3
     U = num_chunks * chunk_length
-    L, R = 128, 4
+    left_context_length, right_context_length = 128, 4
     B, D = 2, 256
     num_encoder_layers = 2
     for use_memory in [True, False]:
@@ -630,8 +628,8 @@ def test_emformer_infer_batch_single_consistency():
            subsampling_factor=4,
            d_model=D,
            num_encoder_layers=num_encoder_layers,
-            left_context_length=L,
-            right_context_length=R,
+            left_context_length=left_context_length,
+            right_context_length=right_context_length,
            max_memory_size=M,
            vgg_frontend=False,
        )
@@ -689,20 +687,25 @@ def test_emformer_infer_batch_single_consistency():
            ],
        )

-        x = torch.randn(B, U + R + 3, num_features)
+        x = torch.randn(B, U + right_context_length + 3, num_features)
+
+        # batch-wise inference
        batch_logits = []
        batch_states = []
        states = None
        for chunk_idx in range(num_chunks):
            start_idx = chunk_idx * chunk_length
            end_idx = start_idx + chunk_length
-            chunk = x[:, start_idx : end_idx + R + 3]  # noqa
-            lengths = torch.tensor([chunk_length + R + 3]).expand(B)
+            chunk = x[:, start_idx : end_idx + right_context_length + 3]  # noqa
+            lengths = torch.tensor(
+                [chunk_length + right_context_length + 3]
+            ).expand(B)
            logits, output_lengths, states = model.infer(chunk, lengths, states)
            batch_logits.append(logits)
            batch_states.append(save_states(states))
        batch_logits = torch.cat(batch_logits, dim=1)

+        # single-sample inference
        single_logits = []
        for sample_idx in range(B):
            sample = x[sample_idx : sample_idx + 1]  # noqa
@@ -711,17 +714,21 @@ def test_emformer_infer_batch_single_consistency():
            for chunk_idx in range(num_chunks):
                start_idx = chunk_idx * chunk_length
                end_idx = start_idx + chunk_length
-                chunk = sample[:, start_idx : end_idx + R + 3]  # noqa
-                lengths = torch.tensor([chunk_length + R + 3])
+                chunk = sample[
+                    :, start_idx : end_idx + right_context_length + 3
+                ]
+                lengths = torch.tensor(
+                    [chunk_length + right_context_length + 3]
+                )
                logits, output_lengths, states = model.infer(
                    chunk, lengths, states
                )
                chunk_logits.append(logits)
-
                assert_states_equal(batch_states[chunk_idx], states, sample_idx)

            chunk_logits = torch.cat(chunk_logits, dim=1)
            single_logits.append(chunk_logits)
+
        single_logits = torch.cat(single_logits, dim=0)

        assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0)
@@ -734,7 +741,7 @@ def test_emformer_infer_states_stack():
     output_dim = 1000
     chunk_length = 8
     U = chunk_length
-    L, R = 128, 4
+    left_context_length, right_context_length = 128, 4
     B, D = 2, 256
     num_encoder_layers = 2
     for use_memory in [True, False]:
@@ -749,14 +756,14 @@ def test_emformer_infer_states_stack():
            subsampling_factor=4,
            d_model=D,
            num_encoder_layers=num_encoder_layers,
-            left_context_length=L,
-            right_context_length=R,
+            left_context_length=left_context_length,
+            right_context_length=right_context_length,
            max_memory_size=M,
            vgg_frontend=False,
        )

-        x = torch.randn(B, U + R + 3, num_features)
-        x_lens = torch.full((B,), U + R + 3)
+        x = torch.randn(B, U + right_context_length + 3, num_features)
+        x_lens = torch.full((B,), U + right_context_length + 3)
        logits, output_lengths, states = model.infer(
            x,
            x_lens,
@@ -790,8 +797,8 @@ if __name__ == "__main__":
     test_emformer_forward()
     test_emformer_infer()
     # test_emformer_attention_forward_infer_consistency()
-    # test_emformer_layer_forward_infer_consistency()
+    test_emformer_layer_forward_infer_consistency()
     test_emformer_encoder_forward_infer_consistency()
-    # test_emformer_infer_batch_single_consistency()
-    # test_emformer_infer_states_stack()
+    test_emformer_infer_batch_single_consistency()
+    test_emformer_infer_states_stack()
     test_rel_positional_encoding()
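
A minimal usage sketch of the streaming API after this change, modeled on test_emformer_infer_batch_single_consistency above; the positional embeddings are now created inside EmformerEncoder, so callers no longer pass pos_emb to forward() or infer(). The constructor and shape values below are illustrative only, not fixed by the patch.

    import torch
    from emformer import Emformer

    # illustrative configuration, mirroring the streaming tests above
    model = Emformer(
        num_features=80,
        output_dim=1000,
        chunk_length=8,
        subsampling_factor=4,
        d_model=256,
        num_encoder_layers=2,
        left_context_length=128,
        right_context_length=4,
        max_memory_size=3,
        vgg_frontend=False,
    )
    model.eval()

    B, chunk_length, right_context_length = 2, 8, 4
    states = None
    # each chunk carries chunk_length + right_context_length + 3 frames,
    # as in the tests above (the extra 3 frames feed the subsampling frontend)
    chunk = torch.randn(B, chunk_length + right_context_length + 3, 80)
    lengths = torch.full((B,), chunk_length + right_context_length + 3)
    # no pos_emb argument anymore; cached states are threaded chunk by chunk
    logits, output_lengths, states = model.infer(chunk, lengths, states)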