diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/beam_search.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/beam_search.py
index ce8b04afd..7c936b257 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/beam_search.py
@@ -557,7 +557,7 @@ class HypothesisList(object):
         return ", ".join(s)
 
 
-def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
     """Return a ragged shape with axes [utt][num_hyps].
 
     Args:
@@ -648,7 +648,7 @@ def modified_beam_search(
             finalized_B = B[batch_size:] + finalized_B
             B = B[:batch_size]
 
-        hyps_shape = _get_hyps_shape(B).to(device)
+        hyps_shape = get_hyps_shape(B).to(device)
 
         A = [list(b) for b in B]
         B = [HypothesisList() for _ in range(batch_size)]
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
index 9234677a1..bf22e3f2d 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
@@ -302,6 +302,7 @@ def decode_one_chunk(
 
     # update cached states of each stream
     state_list = unstack_states(states)
+    assert len(streams) == len(state_list)
     for i, s in enumerate(state_list):
         streams[i].states = s
 
@@ -358,15 +359,10 @@ def decode_dataset(
     """
     device = next(model.parameters()).device
 
-    opts = FbankOptions()
-    opts.device = device
-    opts.frame_opts.dither = 0
-    opts.frame_opts.snip_edges = False
-    opts.frame_opts.samp_freq = 16000
-    opts.mel_opts.num_bins = 80
-
     log_interval = 300
 
+    fbank = create_streaming_feature_extractor()
+
     decode_results = []
     streams = []
     for num, cut in enumerate(cuts):
@@ -382,7 +378,6 @@ def decode_dataset(
         assert audio.max() <= 1, "Should be normalized to [-1, 1])"
         samples = torch.from_numpy(audio).squeeze(0)
 
-        fbank = create_streaming_feature_extractor()
         feature = fbank(samples)
         stream.set_feature(feature)
         stream.set_ground_truth(cut.supervisions[0].text)
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py
index 5e7f33137..7d55d8e1e 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py
@@ -511,15 +511,75 @@ def test_emformer_infer():
     assert conv_cache.shape == (B, D, kernel_size - 1)
 
 
+def test_state_stack_unstack():
+    from emformer import Emformer, stack_states, unstack_states
+
+    num_features = 80
+    chunk_length = 32
+    encoder_dim = 512
+    num_encoder_layers = 2
+    kernel_size = 31
+    left_context_length = 32
+    right_context_length = 8
+    memory_size = 32
+    batch_size = 2
+
+    model = Emformer(
+        num_features=num_features,
+        chunk_length=chunk_length,
+        subsampling_factor=4,
+        d_model=encoder_dim,
+        num_encoder_layers=num_encoder_layers,
+        cnn_module_kernel=kernel_size,
+        left_context_length=left_context_length,
+        right_context_length=right_context_length,
+        memory_size=memory_size,
+    )
+    attn_caches = [
+        [
+            torch.zeros(memory_size, batch_size, encoder_dim),
+            torch.zeros(left_context_length // 4, batch_size, encoder_dim),
+            torch.zeros(
+                left_context_length // 4,
+                batch_size,
+                encoder_dim,
+            ),
+        ]
+        for _ in range(num_encoder_layers)
+    ]
+    conv_caches = [
+        torch.zeros(batch_size, encoder_dim, kernel_size - 1)
+        for _ in range(num_encoder_layers)
+    ]
+    states = [attn_caches, conv_caches]
+    x = torch.randn(batch_size, 23, num_features)
+    x_lens = torch.full((batch_size,), 23)
+    num_processed_frames = torch.full((batch_size,), 0)
+    y, y_lens, states = model.infer(
+        x, x_lens, num_processed_frames=num_processed_frames, states=states
+    )
+
+    state_list = unstack_states(states)
+    states2 = stack_states(state_list)
+
+    for ss, ss2 in zip(states[0], states2[0]):
+        for s, s2 in zip(ss, ss2):
+            assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
+
+    for s, s2 in zip(states[1], states2[1]):
+        assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
+
+
 if __name__ == "__main__":
-    test_emformer_attention_forward()
-    test_emformer_attention_infer()
-    test_convolution_module_forward()
-    test_convolution_module_infer()
-    test_emformer_encoder_layer_forward()
-    test_emformer_encoder_layer_infer()
-    test_emformer_encoder_forward()
-    test_emformer_encoder_infer()
-    test_emformer_encoder_forward_infer_consistency()
-    test_emformer_forward()
-    test_emformer_infer()
+    # test_emformer_attention_forward()
+    # test_emformer_attention_infer()
+    # test_convolution_module_forward()
+    # test_convolution_module_infer()
+    # test_emformer_encoder_layer_forward()
+    # test_emformer_encoder_layer_infer()
+    # test_emformer_encoder_forward()
+    # test_emformer_encoder_infer()
+    # test_emformer_encoder_forward_infer_consistency()
+    # test_emformer_forward()
+    # test_emformer_infer()
+    test_state_stack_unstack()
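
Note on the new test: test_state_stack_unstack checks that unstack_states followed by stack_states reproduces the cached states tensor-for-tensor, comparing the attention caches and the convolution caches separately. The sketch below is a minimal, generic version of that round-trip comparison for nested lists of tensors; assert_nested_allclose is a hypothetical helper written only for illustration and is not part of the recipe.

from typing import List, Union

import torch


def assert_nested_allclose(
    a: Union[torch.Tensor, List], b: Union[torch.Tensor, List]
) -> None:
    # Hypothetical helper (not in the recipe): walk two nested lists of
    # tensors in lockstep and require element-wise closeness, mirroring
    # the per-tensor checks done in test_state_stack_unstack.
    if isinstance(a, torch.Tensor):
        assert torch.allclose(a, b), f"{a.sum()} vs {b.sum()}"
    else:
        assert len(a) == len(b)
        for x, y in zip(a, b):
            assert_nested_allclose(x, y)


# Example: states shaped like [attn_caches, conv_caches] should be
# unchanged after an unstack/stack round trip.
states = [[torch.zeros(32, 2, 512)], [torch.ones(2, 512, 30)]]
round_tripped = [[t.clone() for t in group] for group in states]
assert_nested_allclose(states, round_tripped)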