diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index 27ce9ddde..635b13f4d 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -41,18 +41,18 @@ def fast_beam_search(
     model:
       An instance of `Transducer`.
     decoding_graph:
-      Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+      Decoding graph used for decoding, may be a TrivialGraph or an HLG.
     encoder_out:
       A tensor of shape (N, T, C) from the encoder.
     encoder_out_lens:
       A tensor of shape (N,) containing the number of frames in `encoder_out`
       before padding.
     beam:
-      Beam value, similar to the beam used in Kaldi..
+      Beam value, similar to the beam used in Kaldi.
     max_states:
       Max states per stream per frame.
     max_contexts:
-      Max contexts pre stream per frame.
+      Max contexts per stream per frame.
     use_max:
       True to use max operation to select the hypothesis with the largest
       log_prob when there are duplicate hypotheses; False to use log-add.
@@ -91,9 +91,8 @@ def fast_beam_search(
         # current_encoder_out is of shape
         # (shape.NumElements(), 1, encoder_out_dim)
         # fmt: off
-        current_encoder_out = torch.index_select(
-            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1)
-        )
+        current_encoder_out = torch.index_select(encoder_out[:, t:t + 1, :], 0,
+                                                 shape.row_ids(1))
         # fmt: on
         logits = model.joiner(
             current_encoder_out,
@@ -105,13 +104,14 @@ def fast_beam_search(
         decoding_streams.advance(log_probs)
     decoding_streams.terminate_and_flush_to_streams()
     lattice = decoding_streams.format_output(encoder_out_lens.tolist())
+    # Note: the lattice is actually an FSA if the graph is a k2.trivial_graph()

     if use_max:
         best_path = one_best_decoding(lattice)
         hyps = get_texts(best_path)
         return hyps
     else:
-        num_paths = 20
+        num_paths = 200
         use_double_scores = True
         nbest_scale = 0.8

@@ -127,6 +127,14 @@ def fast_beam_search(
         # delete token IDs as it is not needed
         del word_fsa.aux_labels
         word_fsa.scores.zero_()
+
+        # remove the state axis: [fsa][state][arc] -> [fsa][arc]
+        word_fsa_shape = word_fsa.arcs.shape().remove_axis(1)
+        num_arcs = (
+            word_fsa_shape.row_splits(1)[1:] - word_fsa_shape.row_splits(1)[:-1]
+        )
+        num_tokens_per_path = num_arcs - 1  # minus one due to the final arc
+
         word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
             word_fsa
         )
@@ -159,8 +167,9 @@ def fast_beam_search(
         path_lattice = k2.top_sort(k2.connect(path_lattice))

         tot_scores = path_lattice.get_tot_scores(
-            use_double_scores=True, log_semiring=True
+            use_double_scores=use_double_scores, log_semiring=True
         )
+        tot_scores = tot_scores / num_tokens_per_path
         ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)

         best_hyp_indexes = ragged_tot_scores.argmax()
@@ -223,7 +232,7 @@ def greedy_search(
            continue

        # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :]
+        current_encoder_out = encoder_out[:, t:t + 1, :]
        # fmt: on
        logits = model.joiner(
            current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
@@ -739,7 +748,7 @@ def _deprecated_modified_beam_search(

    for t in range(T):
        # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :]
+        current_encoder_out = encoder_out[:, t:t + 1, :]
        # current_encoder_out is of shape (1, 1, encoder_out_dim)
        # fmt: on
        A = list(B)
@@ -861,7 +870,7 @@ def beam_search(

    while t < T and sym_per_utt < max_sym_per_utt:
        # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :]
+        current_encoder_out = encoder_out[:, t:t + 1, :]
        # fmt: on
        A = B
        B = HypothesisList()
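The `num_tokens_per_path` addition above normalizes each n-best path's total score by its token count before picking the best hypothesis, so longer paths are not unfairly penalized. A minimal, self-contained sketch (mine, not part of the patch) of the arc-counting trick: removing the state axis turns the FsaVec shape `[fsa][state][arc]` into `[fsa][arc]`, so differences of adjacent `row_splits` entries give the arc count per path, and each linear path carries one extra final arc labelled `-1`.

```python
import k2

# Two linear FSAs; k2.linear_fsa appends the final arc (label -1) itself,
# so the FSAs below have 4 and 3 arcs respectively.
fsa_vec = k2.create_fsa_vec([k2.linear_fsa([1, 2, 3]), k2.linear_fsa([4, 5])])

# [fsa][state][arc] -> [fsa][arc], mirroring the patch.
shape = fsa_vec.arcs.shape().remove_axis(1)
row_splits = shape.row_splits(1)
num_arcs = row_splits[1:] - row_splits[:-1]  # arcs per FSA: [4, 3]
num_tokens_per_path = num_arcs - 1           # drop the final -1 arc: [3, 2]
print(num_tokens_per_path)
```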
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index 7be52b183..c126f4694 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -139,7 +139,10 @@ def get_parser():
         type=int,
         default=4,
         help="""Used only when --decoding-method is
-        beam_search or modified_beam_search""",
+        beam_search or modified_beam_search.
+        It specifies the number of active hypotheses to keep at each
+        time step.
+        """,
     )

     parser.add_argument(
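For context on the clarified help text: at every time step, beam search expands all current hypotheses and keeps only the top-scoring ones. A toy sketch (mine, not from the recipe; the hypothesis layout and scores are made up) of keeping `beam_size` active hypotheses per step:

```python
import heapq

def prune(hyps, beam_size=4):
    """Keep the beam_size hypotheses with the largest log-probability.

    `hyps` is a list of (token_ids, log_prob) pairs.
    """
    return heapq.nlargest(beam_size, hyps, key=lambda h: h[1])

hyps = [((1,), -0.1), ((2,), -2.3), ((3,), -0.7), ((4,), -1.5), ((5,), -3.0)]
print(prune(hyps))  # the 4 best survive; ((5,), -3.0) is pruned
```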