diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py
index 4ba19ebae..67e9f5891 100644
--- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py
@@ -183,9 +183,9 @@ class EmformerAttention(nn.Module):
         attention_probs = nn.functional.softmax(
             attention_weights_float, dim=-1
         ).type_as(attention_weights)
-        attention_probs = nn.functional.dropout(
-            attention_probs, p=float(self.dropout), training=self.training
-        )
+        # attention_probs = nn.functional.dropout(
+        #     attention_probs, p=float(self.dropout), training=self.training
+        # )
         return attention_probs
 
     def _forward_impl(
@@ -955,16 +955,15 @@ class EmformerEncoder(nn.Module):
     def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
         """Hard copy each chunk's right context and concat them."""
         T = x.shape[0]
-        num_segs = math.ceil(
+        num_chunks = math.ceil(
             (T - self.right_context_length) / self.chunk_length
         )
         right_context_blocks = []
-        for seg_idx in range(num_segs - 1):
+        for seg_idx in range(num_chunks - 1):
             start = (seg_idx + 1) * self.chunk_length
             end = start + self.right_context_length
             right_context_blocks.append(x[start:end])
-        last_right_context_start_idx = T - self.right_context_length
-        right_context_blocks.append(x[last_right_context_start_idx:])
+        right_context_blocks.append(x[T - self.right_context_length :])  # noqa
         return torch.cat(right_context_blocks)
 
     def _gen_attention_mask_col_widths(
diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py
index 56cf2035e..5e08640d3 100644
--- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py
+++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py
@@ -342,12 +342,218 @@ def test_emformer_infer():
     )
 
 
+def test_emformer_attention_forward_infer_consistency():
+    from emformer import EmformerEncoder
+
+    chunk_length = 4
+    num_chunks = 3
+    U = chunk_length * num_chunks
+    L, R = 1, 2
+    D = 256
+    num_encoder_layers = 1
+    memory_sizes = [0, 3]
+
+    for M in memory_sizes:
+        encoder = EmformerEncoder(
+            chunk_length=chunk_length,
+            d_model=D,
+            dim_feedforward=1024,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            dropout=0.0,
+        )
+        encoder_layer = encoder.emformer_layers[0]
+
+        x = torch.randn(U + R, 1, D)
+        lengths = torch.tensor([U])
+        right_context = encoder._gen_right_context(x)
+        utterance = x[: x.size(0) - R]
+        attention_mask = encoder._gen_attention_mask(utterance)
+        memory = (
+            encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
+                :-1
+            ]
+            if encoder.use_memory
+            else torch.empty(0).to(dtype=x.dtype, device=x.device)
+        )
+        (
+            forward_output_right_context_utterance,
+            forward_output_memory,
+        ) = encoder_layer._apply_attention_forward(
+            utterance,
+            lengths,
+            right_context,
+            memory,
+            attention_mask,
+        )
+        forward_output_utterance = forward_output_right_context_utterance[
+            right_context.size(0) :  # noqa
+        ]
+
+        state = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[start_idx:end_idx]
+            chunk_right_context = x[end_idx : end_idx + R]  # noqa
+            chunk_lengths = torch.tensor([chunk_length])
+            chunk_memory = (
+                encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(2, 0, 1)
+                if encoder.use_memory
+                else torch.empty(0).to(dtype=x.dtype, device=x.device)
+            )
+            (
+                infer_output_right_context_utterance,
+                infer_output_memory,
+                state,
+            ) = encoder_layer._apply_attention_infer(
+                chunk,
+                chunk_lengths,
+                chunk_right_context,
+                chunk_memory,
+                state,
+            )
+            infer_output_utterance = infer_output_right_context_utterance[
+                chunk_right_context.size(0) :  # noqa
+            ]
+            print(
+                infer_output_utterance
+                - forward_output_utterance[start_idx:end_idx]
+            )
+
+
+def test_emformer_layer_forward_infer_consistency():
+    from emformer import EmformerEncoder
+
+    chunk_length = 4
+    num_chunks = 3
+    U = chunk_length * num_chunks
+    L, R = 1, 2
+    D = 256
+    num_encoder_layers = 1
+    memory_sizes = [0, 3]
+
+    for M in memory_sizes:
+        encoder = EmformerEncoder(
+            chunk_length=chunk_length,
+            d_model=D,
+            dim_feedforward=1024,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            dropout=0.0,
+        )
+        encoder_layer = encoder.emformer_layers[0]
+
+        x = torch.randn(U + R, 1, D)
+        lengths = torch.tensor([U])
+        right_context = encoder._gen_right_context(x)
+        utterance = x[: x.size(0) - R]
+        attention_mask = encoder._gen_attention_mask(utterance)
+        memory = (
+            encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
+                :-1
+            ]
+            if encoder.use_memory
+            else torch.empty(0).to(dtype=x.dtype, device=x.device)
+        )
+        (
+            forward_output_utterance,
+            forward_output_right_context,
+            forward_output_memory,
+        ) = encoder_layer(
+            utterance,
+            lengths,
+            right_context,
+            memory,
+            attention_mask,
+        )
+
+        state = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[start_idx:end_idx]
+            chunk_right_context = x[end_idx : end_idx + R]  # noqa
+            chunk_lengths = torch.tensor([chunk_length])
+            chunk_memory = (
+                encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(2, 0, 1)
+                if encoder.use_memory
+                else torch.empty(0).to(dtype=x.dtype, device=x.device)
+            )
+            (
+                infer_output_utterance,
+                infer_right_context,
+                infer_output_memory,
+                state,
+            ) = encoder_layer.infer(
+                chunk,
+                chunk_lengths,
+                chunk_right_context,
+                chunk_memory,
+                state,
+            )
+            print(
+                infer_output_utterance
+                - forward_output_utterance[start_idx:end_idx]
+            )
+
+
+def test_emformer_encoder_forward_infer_consistency():
+    from emformer import EmformerEncoder
+
+    chunk_length = 4
+    num_chunks = 3
+    U = chunk_length * num_chunks
+    L, R = 1, 2
+    D = 256
+    num_encoder_layers = 3
+    memory_sizes = [0, 3]
+
+    for M in memory_sizes:
+        encoder = EmformerEncoder(
+            chunk_length=chunk_length,
+            d_model=D,
+            dim_feedforward=1024,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            dropout=0.0,
+        )
+
+        x = torch.randn(U + R, 1, D)
+        lengths = torch.tensor([U + R])
+
+        forward_output, forward_output_lengths = encoder(x, lengths)
+
+        states = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[start_idx : end_idx + R]  # noqa
+            chunk_right_context = x[end_idx : end_idx + R]  # noqa
+            chunk_lengths = torch.tensor([chunk_length])
+            infer_output, infer_output_lengths, states = encoder.infer(
+                chunk,
+                chunk_lengths,
+                states,
+            )
+            print(infer_output - forward_output[start_idx:end_idx])
+
+
 if __name__ == "__main__":
-    test_emformer_attention_forward()
-    test_emformer_attention_infer()
-    test_emformer_layer_forward()
-    test_emformer_layer_infer()
-    test_emformer_encoder_forward()
-    test_emformer_encoder_infer()
-    test_emformer_forward()
-    test_emformer_infer()
+    # test_emformer_attention_forward()
+    # test_emformer_attention_infer()
+    # test_emformer_layer_forward()
+    # test_emformer_layer_infer()
+    # test_emformer_encoder_forward()
+    # test_emformer_encoder_infer()
+    # test_emformer_forward()
+    # test_emformer_infer()
+    # test_emformer_attention_forward_infer_consistency()
+    # test_emformer_layer_forward_infer_consistency()
+    test_emformer_encoder_forward_infer_consistency()
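
Note that the three new *_forward_infer_consistency tests only print the difference between the batch-mode (forward) and streaming (infer) outputs; they do not fail automatically when the two paths diverge. A minimal sketch of how such a check could be tightened into an assertion; the helper name and the 1e-5 tolerance are illustrative and not part of the diff:

    import torch

    def assert_forward_infer_consistent(
        infer_out: torch.Tensor, forward_out: torch.Tensor, atol: float = 1e-5
    ) -> None:
        # Fail loudly if the streaming and non-streaming outputs diverge.
        max_diff = (infer_out - forward_out).abs().max().item()
        assert torch.allclose(
            infer_out, forward_out, atol=atol
        ), f"forward/infer mismatch, max abs diff = {max_diff:.3e}"

Each print(...) call in the chunk loops could then be replaced by a call to this helper so the tests fail on a real mismatch instead of requiring manual inspection.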