diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py index c41c8a3c0..dcf6dc42f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py @@ -55,7 +55,7 @@ def greedy_search( device=device, dtype=torch.int64, ) - # decoder_out is of shape (N, decoder_out_dim) + # decoder_out is of shape (N, 1, decoder_out_dim) decoder_out = model.decoder(decoder_input, need_pad=False) for t in range(T): @@ -93,7 +93,7 @@ def modified_beam_search( model: nn.Module, encoder_out: torch.Tensor, streams: List[DecodeStream], - beam: int = 4, + num_active_paths: int = 4, ) -> None: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. @@ -105,7 +105,7 @@ def modified_beam_search( the encoder model. streams: A list of stream objects. - beam: + num_active_paths: Number of active paths during the beam search. 
""" assert encoder_out.ndim == 3, encoder_out.shape @@ -171,7 +171,9 @@ def modified_beam_search( ) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk( + num_active_paths + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index 548cbc770..e455627f3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -17,13 +17,13 @@ """ Usage: -./pruned_transducer_stateless2/streaming_decode.py \ +./pruned_transducer_stateless/streaming_decode.py \ --epoch 28 \ --avg 15 \ --decode-chunk-size 8 \ --left-context 32 \ --right-context 0 \ - --exp-dir ./pruned_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless/exp \ --decoding_method greedy_search \ --num-decode-streams 1000 """ @@ -125,7 +125,7 @@ def get_parser(): ) parser.add_argument( - "--beam-size", + "--num-active-paths", type=int, default=4, help="""An interger indicating how many candidates we will keep for each @@ -288,7 +288,7 @@ def decode_one_chunk( model=model, streams=decode_streams, encoder_out=encoder_out, - beam=params.beam_size, + num_active_paths=params.num_active_paths, ) else: raise ValueError( @@ -417,7 +417,7 @@ def decode_dataset( f"max_states_{params.max_states}" ) elif params.decoding_method == "modified_beam_search": - key = f"beam_size_{params.beam_size}" + key = f"num_active_paths_{params.num_active_paths}" else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py index 9cd2e7f43..9bcd2f9f9 100644 --- 
a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py @@ -55,7 +55,7 @@ def greedy_search( device=device, dtype=torch.int64, ) - # decoder_out is of shape (N, decoder_out_dim) + # decoder_out is of shape (N, 1, decoder_out_dim) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -96,7 +96,7 @@ def modified_beam_search( model: nn.Module, encoder_out: torch.Tensor, streams: List[DecodeStream], - beam: int = 4, + num_active_paths: int = 4, ) -> None: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. @@ -108,7 +108,7 @@ def modified_beam_search( the encoder model. streams: A list of stream objects. - beam: + num_active_paths: Number of active paths during the beam search. """ assert encoder_out.ndim == 3, encoder_out.shape @@ -177,7 +177,9 @@ def modified_beam_search( ) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk( + num_active_paths + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index b2b482cb4..79963c968 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -125,7 +125,7 @@ def get_parser(): ) parser.add_argument( - "--beam-size", + "--num-active-paths", type=int, default=4, help="""An interger indicating how many candidates we will keep for each @@ -290,7 +290,7 @@ def decode_one_chunk( model=model, streams=decode_streams, encoder_out=encoder_out, - beam=params.beam_size, + num_active_paths=params.num_active_paths, ) else: raise ValueError( @@ -420,7 +420,7 @@ def decode_dataset( 
f"max_states_{params.max_states}" ) elif params.decoding_method == "modified_beam_search": - key = f"beam_size_{params.beam_size}" + key = f"num_active_paths_{params.num_active_paths}" else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index ed638bd0a..1976d19a6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -17,13 +17,13 @@ """ Usage: -./pruned_transducer_stateless2/streaming_decode.py \ +./pruned_transducer_stateless3/streaming_decode.py \ --epoch 28 \ --avg 15 \ --left-context 32 \ --decode-chunk-size 8 \ --right-context 0 \ - --exp-dir ./pruned_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless3/exp \ --decoding_method greedy_search \ --num-decode-streams 1000 """ @@ -126,7 +126,7 @@ def get_parser(): ) parser.add_argument( - "--beam-size", + "--num-active-paths", type=int, default=4, help="""An interger indicating how many candidates we will keep for each @@ -291,7 +291,7 @@ def decode_one_chunk( model=model, streams=decode_streams, encoder_out=encoder_out, - beam=params.beam_size, + num_active_paths=params.num_active_paths, ) else: raise ValueError( @@ -421,7 +421,7 @@ def decode_dataset( f"max_states_{params.max_states}" ) elif params.decoding_method == "modified_beam_search": - key = f"beam_size_{params.beam_size}" + key = f"num_active_paths_{params.num_active_paths}" else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index 171f17f03..de89d41c2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -17,13 +17,13 @@ """ Usage: -./pruned_transducer_stateless2/streaming_decode.py \ +./pruned_transducer_stateless4/streaming_decode.py \ --epoch 28 \ --avg 15 \ --left-context 32 \ --decode-chunk-size 8 \ --right-context 0 \ - --exp-dir ./pruned_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless4/exp \ --decoding_method greedy_search \ --num-decode-streams 200 """ @@ -138,7 +138,7 @@ def get_parser(): ) parser.add_argument( - "--beam-size", + "--num-active-paths", type=int, default=4, help="""An interger indicating how many candidates we will keep for each @@ -303,7 +303,7 @@ def decode_one_chunk( model=model, streams=decode_streams, encoder_out=encoder_out, - beam=params.beam_size, + num_active_paths=params.num_active_paths, ) else: raise ValueError( @@ -433,7 +433,7 @@ def decode_dataset( f"max_states_{params.max_states}" ) elif params.decoding_method == "modified_beam_search": - key = f"beam_size_{params.beam_size}" + key = f"num_active_paths_{params.num_active_paths}" else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}"