mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 06:04:18 +00:00
Minor fixes.
This commit is contained in:
parent
52f1f6775d
commit
7c5249fb88
@ -41,18 +41,18 @@ def fast_beam_search(
|
||||
model:
|
||||
An instance of `Transducer`.
|
||||
decoding_graph:
|
||||
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||
Decoding graph used for decoding, may be a TrivialGraph or an HLG.
|
||||
encoder_out:
|
||||
A tensor of shape (N, T, C) from the encoder.
|
||||
encoder_out_lens:
|
||||
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||
before padding.
|
||||
beam:
|
||||
Beam value, similar to the beam used in Kaldi..
|
||||
Beam value, similar to the beam used in Kaldi.
|
||||
max_states:
|
||||
Max states per stream per frame.
|
||||
max_contexts:
|
||||
Max contexts pre stream per frame.
|
||||
Max contexts per stream per frame.
|
||||
use_max:
|
||||
True to use max operation to select the hypothesis with the largest
|
||||
log_prob when there are duplicate hypotheses; False to use log-add.
|
||||
@ -91,9 +91,8 @@ def fast_beam_search(
|
||||
# current_encoder_out is of shape
|
||||
# (shape.NumElements(), 1, encoder_out_dim)
|
||||
# fmt: off
|
||||
current_encoder_out = torch.index_select(
|
||||
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1)
|
||||
)
|
||||
current_encoder_out = torch.index_select(encoder_out[:, t:t + 1, :], 0,
|
||||
shape.row_ids(1))
|
||||
# fmt: on
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
@ -105,13 +104,14 @@ def fast_beam_search(
|
||||
decoding_streams.advance(log_probs)
|
||||
decoding_streams.terminate_and_flush_to_streams()
|
||||
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
||||
# Note: lattice is actually is an FSA if the graph is a k2.trivial_graph()
|
||||
|
||||
if use_max:
|
||||
best_path = one_best_decoding(lattice)
|
||||
hyps = get_texts(best_path)
|
||||
return hyps
|
||||
else:
|
||||
num_paths = 20
|
||||
num_paths = 200
|
||||
use_double_scores = True
|
||||
nbest_scale = 0.8
|
||||
|
||||
@ -127,6 +127,14 @@ def fast_beam_search(
|
||||
# delete token IDs as it is not needed
|
||||
del word_fsa.aux_labels
|
||||
word_fsa.scores.zero_()
|
||||
|
||||
# remove the state axis: [fsa][state][arc] -> [fsa][arc]
|
||||
word_fsa_shape = word_fsa.arcs.shape().remove_axis(1)
|
||||
num_arcs = (
|
||||
word_fsa_shape.row_splits(1)[1:] - word_fsa_shape.row_splits(1)[:-1]
|
||||
)
|
||||
num_tokens_per_path = num_arcs - 1 # minus one due to the final arc
|
||||
|
||||
word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
|
||||
word_fsa
|
||||
)
|
||||
@ -159,8 +167,9 @@ def fast_beam_search(
|
||||
path_lattice = k2.top_sort(k2.connect(path_lattice))
|
||||
|
||||
tot_scores = path_lattice.get_tot_scores(
|
||||
use_double_scores=True, log_semiring=True
|
||||
use_double_scores=use_double_scores, log_semiring=True
|
||||
)
|
||||
tot_scores = tot_scores / num_tokens_per_path
|
||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||
best_hyp_indexes = ragged_tot_scores.argmax()
|
||||
|
||||
@ -223,7 +232,7 @@ def greedy_search(
|
||||
continue
|
||||
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||
current_encoder_out = encoder_out[:, t:t + 1, :]
|
||||
# fmt: on
|
||||
logits = model.joiner(
|
||||
current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
|
||||
@ -739,7 +748,7 @@ def _deprecated_modified_beam_search(
|
||||
|
||||
for t in range(T):
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||
current_encoder_out = encoder_out[:, t:t + 1, :]
|
||||
# current_encoder_out is of shape (1, 1, encoder_out_dim)
|
||||
# fmt: on
|
||||
A = list(B)
|
||||
@ -861,7 +870,7 @@ def beam_search(
|
||||
|
||||
while t < T and sym_per_utt < max_sym_per_utt:
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||
current_encoder_out = encoder_out[:, t:t + 1, :]
|
||||
# fmt: on
|
||||
A = B
|
||||
B = HypothesisList()
|
||||
|
@ -139,7 +139,10 @@ def get_parser():
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --decoding-method is
|
||||
beam_search or modified_beam_search""",
|
||||
beam_search or modified_beam_search.
|
||||
It specifies the number of active hypotheses to keep at each
|
||||
time step.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
Loading…
x
Reference in New Issue
Block a user