diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py
index 67e9f5891..9eb5b966f 100644
--- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py
@@ -85,8 +85,6 @@ class EmformerAttention(nn.Module):
         Embedding dimension.
       nhead (int):
         Number of attention heads in each Emformer layer.
-      dropout (float, optional):
-        Dropout probability. (Default: 0.0)
       weight_init_gain (float or None, optional):
         Scale factor to apply when initializing attention module parameters.
         (Default: ``None``)
@@ -100,7 +98,6 @@ class EmformerAttention(nn.Module):
         self,
         embed_dim: int,
         nhead: int,
-        dropout: float = 0.0,
         weight_init_gain: Optional[float] = None,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
@@ -115,7 +112,6 @@ class EmformerAttention(nn.Module):
 
         self.embed_dim = embed_dim
         self.nhead = nhead
-        self.dropout = dropout
         self.tanh_on_mem = tanh_on_mem
         self.negative_inf = negative_inf
 
@@ -183,9 +179,7 @@ class EmformerAttention(nn.Module):
         attention_probs = nn.functional.softmax(
             attention_weights_float, dim=-1
         ).type_as(attention_weights)
-        # attention_probs = nn.functional.dropout(
-        #     attention_probs, p=float(self.dropout), training=self.training
-        # )
+
         return attention_probs
 
     def _forward_impl(
@@ -512,7 +506,6 @@ class EmformerLayer(nn.Module):
         self.attention = EmformerAttention(
             embed_dim=d_model,
             nhead=nhead,
-            dropout=dropout,
             weight_init_gain=weight_init_gain,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py
index 5e08640d3..abc023bb7 100644
--- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py
+++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py
@@ -362,8 +362,9 @@ def test_emformer_attention_forward_infer_consistency():
         left_context_length=L,
         right_context_length=R,
         max_memory_size=M,
-        dropout=0.0,
+        dropout=0.1,
     )
+    encoder.eval()
     encoder_layer = encoder.emformer_layers[0]
 
     x = torch.randn(U + R, 1, D)
@@ -415,12 +416,15 @@
             chunk_memory,
             state,
         )
-        infer_output_utterance = infer_output_right_context_utterance[
+        infer_output_chunk = infer_output_right_context_utterance[
             chunk_right_context.size(0) :  # noqa
         ]
-        print(
-            infer_output_utterance
-            - forward_output_utterance[start_idx:end_idx]
+        forward_output_chunk = forward_output_utterance[start_idx:end_idx]
+        assert torch.allclose(
+            infer_output_chunk,
+            forward_output_chunk,
+            atol=1e-6,
+            rtol=0.0,
         )
 
 
@@ -444,8 +448,9 @@ def test_emformer_layer_forward_infer_consistency():
         left_context_length=L,
         right_context_length=R,
         max_memory_size=M,
-        dropout=0.0,
+        dropout=0.1,
     )
+    encoder.eval()
     encoder_layer = encoder.emformer_layers[0]
 
     x = torch.randn(U + R, 1, D)
@@ -485,7 +490,7 @@
             else torch.empty(0).to(dtype=x.dtype, device=x.device)
         )
         (
-            infer_output_utterance,
+            infer_output_chunk,
             infer_right_context,
             infer_output_memory,
             state,
@@ -496,9 +501,12 @@
             chunk_memory,
             state,
         )
-        print(
-            infer_output_utterance
-            - forward_output_utterance[start_idx:end_idx]
+        forward_output_chunk = forward_output_utterance[start_idx:end_idx]
+        assert torch.allclose(
+            infer_output_chunk,
+            forward_output_chunk,
+            atol=1e-5,
+            rtol=0.0,
         )
 
 
@@ -522,8 +530,9 @@ def test_emformer_encoder_forward_infer_consistency():
         left_context_length=L,
         right_context_length=R,
         max_memory_size=M,
-        dropout=0.0,
+        dropout=0.1,
     )
+    encoder.eval()
 
     x = torch.randn(U + R, 1, D)
     lengths = torch.tensor([U + R])
@@ -537,23 +546,152 @@ def test_emformer_encoder_forward_infer_consistency():
         chunk = x[start_idx : end_idx + R]  # noqa
         chunk_right_context = x[end_idx : end_idx + R]  # noqa
         chunk_length = torch.tensor([chunk_length])
-        infer_output, infer_output_lengths, states = encoder.infer(
+        infer_output_chunk, infer_output_lengths, states = encoder.infer(
             chunk,
             chunk_length,
             states,
         )
-        print(infer_output - forward_output[start_idx:end_idx])
+        forward_output_chunk = forward_output[start_idx:end_idx]
+        assert torch.allclose(
+            infer_output_chunk,
+            forward_output_chunk,
+            atol=1e-5,
+            rtol=0.0,
+        )
+
+
+def test_emformer_infer_batch_single_consistency():
+    """Test consistency of cached states and output logits between single
+    utterance inference and batch inference."""
+    from emformer import Emformer
+
+    num_features = 80
+    output_dim = 1000
+    chunk_length = 8
+    num_chunks = 3
+    U = num_chunks * chunk_length
+    L, R = 128, 4
+    B, D = 2, 256
+    num_encoder_layers = 2
+    for use_memory in [True, False]:
+        if use_memory:
+            M = 3
+        else:
+            M = 0
+        model = Emformer(
+            num_features=num_features,
+            output_dim=output_dim,
+            chunk_length=chunk_length,
+            subsampling_factor=4,
+            d_model=D,
+            num_encoder_layers=num_encoder_layers,
+            left_context_length=L,
+            right_context_length=R,
+            max_memory_size=M,
+            vgg_frontend=False,
+        )
+        model.eval()
+
+        def save_states(states):
+            saved_states = []
+            for layer_idx in range(len(states)):
+                layer_state = []
+                layer_state.append(states[layer_idx][0].clone())  # memory
+                layer_state.append(
+                    states[layer_idx][1].clone()
+                )  # left_context_key
+                layer_state.append(
+                    states[layer_idx][2].clone()
+                )  # left_context_val
+                layer_state.append(states[layer_idx][3].clone())  # past_length
+                saved_states.append(layer_state)
+            return saved_states
+
+        def assert_states_equal(saved_states, states, sample_idx):
+            for layer_idx in range(len(saved_states)):
+                # assert equal memory
+                assert torch.allclose(
+                    states[layer_idx][0],
+                    saved_states[layer_idx][0][
+                        :, sample_idx : sample_idx + 1  # noqa
+                    ],
+                    atol=1e-5,
+                    rtol=0.0,
+                )
+                # assert equal left_context_key
+                assert torch.allclose(
+                    states[layer_idx][1],
+                    saved_states[layer_idx][1][
+                        :, sample_idx : sample_idx + 1  # noqa
+                    ],
+                    atol=1e-5,
+                    rtol=0.0,
+                )
+                # assert equal left_context_val
+                assert torch.allclose(
+                    states[layer_idx][2],
+                    saved_states[layer_idx][2][
+                        :, sample_idx : sample_idx + 1  # noqa
+                    ],
+                    atol=1e-5,
+                    rtol=0.0,
+                )
+                # assert equal past_length
+                assert torch.equal(
+                    states[layer_idx][3],
+                    saved_states[layer_idx][3][
+                        :, sample_idx : sample_idx + 1  # noqa
+                    ],
+                )
+
+        x = torch.randn(B, U + R + 3, num_features)
+        batch_logits = []
+        batch_states = []
+        states = None
+        for chunk_idx in range(num_chunks):
+            start_idx = chunk_idx * chunk_length
+            end_idx = start_idx + chunk_length
+            chunk = x[:, start_idx : end_idx + R + 3]  # noqa
+            lengths = torch.tensor([chunk_length + R + 3]).expand(B)
+            logits, output_lengths, states = model.infer(chunk, lengths, states)
+            batch_logits.append(logits)
+            batch_states.append(save_states(states))
+        batch_logits = torch.cat(batch_logits, dim=1)
+
+        single_logits = []
+        for sample_idx in range(B):
+            sample = x[sample_idx : sample_idx + 1]  # noqa
+            chunk_logits = []
+            states = None
+            for chunk_idx in range(num_chunks):
+                start_idx = chunk_idx * chunk_length
+                end_idx = start_idx + chunk_length
+                chunk = sample[:, start_idx : end_idx + R + 3]  # noqa
+                lengths = torch.tensor([chunk_length + R + 3])
+                logits, output_lengths, states = model.infer(
+                    chunk, lengths, states
+                )
+                chunk_logits.append(logits)
+
+                assert_states_equal(batch_states[chunk_idx], states, sample_idx)
+
+            chunk_logits = torch.cat(chunk_logits, dim=1)
+            single_logits.append(chunk_logits)
+        single_logits = torch.cat(single_logits, dim=0)
+
+        assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0)
 
 
 if __name__ == "__main__":
-    # test_emformer_attention_forward()
-    # test_emformer_attention_infer()
-    # test_emformer_layer_forward()
-    # test_emformer_layer_infer()
-    # test_emformer_encoder_forward()
-    # test_emformer_encoder_infer()
-    # test_emformer_forward()
-    # test_emformer_infer()
-    # test_emformer_attention_forward_infer_consistency()
-    # test_emformer_layer_forward_infer_consistency()
+    test_emformer_attention_forward()
+    test_emformer_attention_infer()
+    test_emformer_layer_forward()
+    test_emformer_layer_infer()
+    test_emformer_encoder_forward()
+    test_emformer_encoder_infer()
+    test_emformer_forward()
+    test_emformer_infer()
+    test_emformer_attention_forward_infer_consistency()
+    test_emformer_layer_forward_infer_consistency()
     test_emformer_encoder_forward_infer_consistency()
+    test_emformer_infer_batch_single_consistency()
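
Note on the test changes above: the consistency tests now build the encoder with dropout=0.1 instead of 0.0 and call .eval() before asserting forward/infer agreement with torch.allclose. This works because eval mode disables dropout, so both code paths are deterministic even with a nonzero dropout probability. Below is a minimal sketch of that property; it is not part of the patch and uses only plain PyTorch, nothing specific to the Emformer code.

import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Sequential(nn.Linear(64, 64), nn.Dropout(p=0.1))
x = torch.randn(32, 64)

# In training mode, dropout samples a fresh mask on every call, so two
# passes over the same input almost surely differ.
layer.train()
y1, y2 = layer(x), layer(x)
assert not torch.allclose(y1, y2)

# In eval mode, dropout is the identity, so the output is deterministic
# and comparisons like the ones in the tests above are meaningful.
layer.eval()
y1, y2 = layer(x), layer(x)
assert torch.allclose(y1, y2)

The assertions still use a small atol (1e-5 or 1e-6) with rtol=0.0 rather than exact equality, since the forward and infer paths may order floating-point operations differently even over identical weights.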