Minor fixes.

Fangjun Kuang 2022-03-29 16:10:19 +08:00
parent 52f1f6775d
commit 7c5249fb88
2 changed files with 24 additions and 12 deletions

View File

@ -41,18 +41,18 @@ def fast_beam_search(
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
Decoding graph used for decoding, may be a TrivialGraph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi..
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
Max contexts per stream per frame.
use_max:
True to use max operation to select the hypothesis with the largest
log_prob when there are duplicate hypotheses; False to use log-add.
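For context, a call consistent with this docstring might look as follows. This is a hedged sketch, assuming a trained `model` and its encoder output are already in hand; the vocabulary size 500 passed to k2.trivial_graph is a placeholder:

import k2

decoding_graph = k2.trivial_graph(500, device=encoder_out.device)
hyps = fast_beam_search(
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,            # (N, T, C)
    encoder_out_lens=encoder_out_lens,  # (N,)
    beam=4.0,
    max_states=32,
    max_contexts=8,
    use_max=False,  # log-add duplicate hypotheses instead of taking max
)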
@ -91,9 +91,8 @@ def fast_beam_search(
# current_encoder_out is of shape
# (shape.NumElements(), 1, encoder_out_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1)
)
current_encoder_out = torch.index_select(encoder_out[:, t:t + 1, :], 0,
shape.row_ids(1))
# fmt: on
logits = model.joiner(
current_encoder_out,
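The reformatted call gathers, for every active decoding stream, frame t of the utterance that stream belongs to. A toy, self-contained illustration of the same gather (the shapes are invented for the demo):

import torch

# 2 utterances, 4 frames, feature dim 5; 3 active streams whose
# utterance indices (what shape.row_ids(1) returns) are [0, 0, 1].
encoder_out = torch.randn(2, 4, 5)
row_ids = torch.tensor([0, 0, 1])
t = 2
frame = encoder_out[:, t:t + 1, :]               # (N, 1, C)
current = torch.index_select(frame, 0, row_ids)  # (num_streams, 1, C)
assert current.shape == (3, 1, 5)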
@ -105,13 +104,14 @@ def fast_beam_search(
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
# Note: lattice is actually an FSA if the graph is a k2.trivial_graph()
if use_max:
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
else:
num_paths = 20
num_paths = 200
use_double_scores = True
nbest_scale = 0.8
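Raising num_paths from 20 to 200 gives the log-add branch a larger n-best list to rescore. Under the hood this controls how many paths are sampled from the lattice; a rough sketch, assuming `lattice` is the FsaVec returned by format_output() above:

import k2

# More sampled paths -> a denser n-best list for the rescoring below.
paths = k2.random_paths(lattice, use_double_scores=True, num_paths=200)
# `paths` is a k2.RaggedTensor with axes [utt][path][arc_pos].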
@ -127,6 +127,14 @@ def fast_beam_search(
# delete token IDs as they are not needed
del word_fsa.aux_labels
word_fsa.scores.zero_()
# remove the state axis: [fsa][state][arc] -> [fsa][arc]
word_fsa_shape = word_fsa.arcs.shape().remove_axis(1)
num_arcs = (
word_fsa_shape.row_splits(1)[1:] - word_fsa_shape.row_splits(1)[:-1]
)
num_tokens_per_path = num_arcs - 1 # minus one due to the final arc
word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
word_fsa
)
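The added lines recover a per-path arc count from the ragged shape: row_splits(1) holds cumulative arc offsets, so adjacent differences give the number of arcs in each path, and subtracting one drops the final arc, which carries no token. A toy illustration with plain tensors (the numbers are invented):

import torch

# row_splits [0, 4, 6, 9] describes 3 paths holding 4, 2 and 3 arcs.
row_splits = torch.tensor([0, 4, 6, 9])
num_arcs = row_splits[1:] - row_splits[:-1]  # tensor([4, 2, 3])
num_tokens_per_path = num_arcs - 1           # final arc carries no token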
@ -159,8 +167,9 @@ def fast_beam_search(
path_lattice = k2.top_sort(k2.connect(path_lattice))
tot_scores = path_lattice.get_tot_scores(
use_double_scores=True, log_semiring=True
use_double_scores=use_double_scores, log_semiring=True
)
tot_scores = tot_scores / num_tokens_per_path
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
best_hyp_indexes = ragged_tot_scores.argmax()
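Dividing tot_scores by num_tokens_per_path length-normalizes the path scores, so the argmax no longer favors shorter hypotheses outright. A toy sketch of the ragged argmax itself, with invented scores:

import k2

# 5 path scores grouped into 2 utterances (3 + 2 paths); argmax()
# returns one index per utterance into the flattened score list.
ragged_scores = k2.RaggedTensor([[0.1, 0.7, 0.3], [0.9, 0.2]])
best_hyp_indexes = ragged_scores.argmax()  # tensor([1, 3])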
@ -223,7 +232,7 @@ def greedy_search(
continue
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
current_encoder_out = encoder_out[:, t:t + 1, :]
# fmt: on
logits = model.joiner(
current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
@ -739,7 +748,7 @@ def _deprecated_modified_beam_search(
for t in range(T):
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
current_encoder_out = encoder_out[:, t:t + 1, :]
# current_encoder_out is of shape (1, 1, encoder_out_dim)
# fmt: on
A = list(B)
@ -861,7 +870,7 @@ def beam_search(
while t < T and sym_per_utt < max_sym_per_utt:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
current_encoder_out = encoder_out[:, t:t + 1, :]
# fmt: on
A = B
B = HypothesisList()
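These last three hunks only normalize spacing inside the fmt: off regions (t:t+1 becomes t:t + 1); behavior is unchanged. The slice itself is worth a note, though: t:t + 1 keeps the time axis, which plain indexing would drop, so the joiner still sees a 3-D tensor. A quick demo:

import torch

x = torch.zeros(8, 10, 512)       # (N, T, C)
print(x[:, 3, :].shape)           # torch.Size([8, 512])
print(x[:, 3:3 + 1, :].shape)     # torch.Size([8, 1, 512])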

View File

@ -139,7 +139,10 @@ def get_parser():
type=int,
default=4,
help="""Used only when --decoding-method is
beam_search or modified_beam_search""",
beam_search or modified_beam_search.
It specifies the number of active hypotheses to keep at each
time step.
""",
)
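The option name sits above this hunk and is not shown; assuming the flag is --beam-size (a hypothetical name here), the updated definition would read roughly:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--beam-size",  # assumed flag name; the real one is outside the hunk
    type=int,
    default=4,
    help="""Used only when --decoding-method is
    beam_search or modified_beam_search.
    It specifies the number of active hypotheses to keep at each
    time step.
    """,
)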
parser.add_argument(