Minor fixes.

This commit is contained in:
Fangjun Kuang 2022-03-29 16:10:19 +08:00
parent 52f1f6775d
commit 7c5249fb88
2 changed files with 24 additions and 12 deletions

View File

@ -41,18 +41,18 @@ def fast_beam_search(
model: model:
An instance of `Transducer`. An instance of `Transducer`.
decoding_graph: decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG. Decoding graph used for decoding, may be a TrivialGraph or an HLG.
encoder_out: encoder_out:
A tensor of shape (N, T, C) from the encoder. A tensor of shape (N, T, C) from the encoder.
encoder_out_lens: encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out` A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding. before padding.
beam: beam:
Beam value, similar to the beam used in Kaldi.. Beam value, similar to the beam used in Kaldi.
max_states: max_states:
Max states per stream per frame. Max states per stream per frame.
max_contexts: max_contexts:
Max contexts pre stream per frame. Max contexts per stream per frame.
use_max: use_max:
True to use max operation to select the hypothesis with the largest True to use max operation to select the hypothesis with the largest
log_prob when there are duplicate hypotheses; False to use log-add. log_prob when there are duplicate hypotheses; False to use log-add.
@ -91,9 +91,8 @@ def fast_beam_search(
# current_encoder_out is of shape # current_encoder_out is of shape
# (shape.NumElements(), 1, encoder_out_dim) # (shape.NumElements(), 1, encoder_out_dim)
# fmt: off # fmt: off
current_encoder_out = torch.index_select( current_encoder_out = torch.index_select(encoder_out[:, t:t + 1, :], 0,
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1) shape.row_ids(1))
)
# fmt: on # fmt: on
logits = model.joiner( logits = model.joiner(
current_encoder_out, current_encoder_out,
@ -105,13 +104,14 @@ def fast_beam_search(
decoding_streams.advance(log_probs) decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams() decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist()) lattice = decoding_streams.format_output(encoder_out_lens.tolist())
# Note: lattice is actually an FSA if the graph is a k2.trivial_graph()
if use_max: if use_max:
best_path = one_best_decoding(lattice) best_path = one_best_decoding(lattice)
hyps = get_texts(best_path) hyps = get_texts(best_path)
return hyps return hyps
else: else:
num_paths = 20 num_paths = 200
use_double_scores = True use_double_scores = True
nbest_scale = 0.8 nbest_scale = 0.8
@ -127,6 +127,14 @@ def fast_beam_search(
# delete token IDs as it is not needed # delete token IDs as it is not needed
del word_fsa.aux_labels del word_fsa.aux_labels
word_fsa.scores.zero_() word_fsa.scores.zero_()
# remove the state axis: [fsa][state][arc] -> [fsa][arc]
word_fsa_shape = word_fsa.arcs.shape().remove_axis(1)
num_arcs = (
word_fsa_shape.row_splits(1)[1:] - word_fsa_shape.row_splits(1)[:-1]
)
num_tokens_per_path = num_arcs - 1 # minus one due to the final arc
word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops( word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
word_fsa word_fsa
) )
@ -159,8 +167,9 @@ def fast_beam_search(
path_lattice = k2.top_sort(k2.connect(path_lattice)) path_lattice = k2.top_sort(k2.connect(path_lattice))
tot_scores = path_lattice.get_tot_scores( tot_scores = path_lattice.get_tot_scores(
use_double_scores=True, log_semiring=True use_double_scores=use_double_scores, log_semiring=True
) )
tot_scores = tot_scores / num_tokens_per_path
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
best_hyp_indexes = ragged_tot_scores.argmax() best_hyp_indexes = ragged_tot_scores.argmax()
@ -223,7 +232,7 @@ def greedy_search(
continue continue
# fmt: off # fmt: off
current_encoder_out = encoder_out[:, t:t+1, :] current_encoder_out = encoder_out[:, t:t + 1, :]
# fmt: on # fmt: on
logits = model.joiner( logits = model.joiner(
current_encoder_out, decoder_out, encoder_out_len, decoder_out_len current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
@ -739,7 +748,7 @@ def _deprecated_modified_beam_search(
for t in range(T): for t in range(T):
# fmt: off # fmt: off
current_encoder_out = encoder_out[:, t:t+1, :] current_encoder_out = encoder_out[:, t:t + 1, :]
# current_encoder_out is of shape (1, 1, encoder_out_dim) # current_encoder_out is of shape (1, 1, encoder_out_dim)
# fmt: on # fmt: on
A = list(B) A = list(B)
@ -861,7 +870,7 @@ def beam_search(
while t < T and sym_per_utt < max_sym_per_utt: while t < T and sym_per_utt < max_sym_per_utt:
# fmt: off # fmt: off
current_encoder_out = encoder_out[:, t:t+1, :] current_encoder_out = encoder_out[:, t:t + 1, :]
# fmt: on # fmt: on
A = B A = B
B = HypothesisList() B = HypothesisList()

View File

@ -139,7 +139,10 @@ def get_parser():
type=int, type=int,
default=4, default=4,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is
beam_search or modified_beam_search""", beam_search or modified_beam_search.
It specifies the number of active hypotheses to keep at each
time step.
""",
) )
parser.add_argument( parser.add_argument(