From 32de2766d591d2e1a77c06a40d2861fb1bbcd3ad Mon Sep 17 00:00:00 2001
From: Zengwei Yao
Date: Sat, 5 Nov 2022 22:36:06 +0800
Subject: [PATCH] Refactor getting timestamps in fsa-based decoding (#660)

* refactor getting timestamps for fsa-based decoding

* fix doc

* fix bug
---
 .../ASR/lstm_transducer_stateless3/decode.py |  2 +-
 .../beam_search.py                           | 10 +--
 .../pruned_transducer_stateless4/decode.py   |  2 +-
 icefall/utils.py                             | 74 +++++++++----------
 4 files changed, 41 insertions(+), 47 deletions(-)

diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
index 052d027e3..9eee19379 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
@@ -487,7 +487,7 @@ def decode_one_batch(
             )
             tokens.extend(res.tokens)
             timestamps.extend(res.timestamps)
-        res = DecodingResults(tokens=tokens, timestamps=timestamps)
+        res = DecodingResults(hyps=tokens, timestamps=timestamps)
 
     hyps, timestamps = parse_hyp_and_timestamp(
         decoding_method=params.decoding_method,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index b1fd75204..a3fa6cc7c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -598,7 +598,7 @@ def greedy_search(
         return hyp
     else:
         return DecodingResults(
-            tokens=[hyp],
+            hyps=[hyp],
             timestamps=[timestamp],
         )
 
@@ -712,7 +712,7 @@ def greedy_search_batch(
         return ans
     else:
         return DecodingResults(
-            tokens=ans,
+            hyps=ans,
             timestamps=ans_timestamps,
         )
 
@@ -1049,7 +1049,7 @@ def modified_beam_search(
         return ans
     else:
         return DecodingResults(
-            tokens=ans,
+            hyps=ans,
             timestamps=ans_timestamps,
         )
 
@@ -1176,7 +1176,7 @@ def _deprecated_modified_beam_search(
     if not return_timestamps:
         return ys
     else:
-        return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
+        return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp])
 
 
 def beam_search(
@@ -1336,7 +1336,7 @@ def beam_search(
     if not return_timestamps:
         return ys
     else:
-        return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
+        return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp])
 
 
 def fast_beam_search_with_nbest_rescoring(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index 7003e4764..4f043e5a6 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -531,7 +531,7 @@ def decode_one_batch(
             )
             tokens.extend(res.tokens)
             timestamps.extend(res.timestamps)
-        res = DecodingResults(tokens=tokens, timestamps=timestamps)
+        res = DecodingResults(hyps=tokens, timestamps=timestamps)
 
     hyps, timestamps = parse_hyp_and_timestamp(
         decoding_method=params.decoding_method,
diff --git a/icefall/utils.py b/icefall/utils.py
index 93dd0b967..e83fccdde 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -251,33 +251,20 @@ def get_texts(
 
 @dataclass
 class DecodingResults:
-    # Decoded token IDs for each utterance in the batch
-    tokens: List[List[int]]
-
     # timestamps[i][k] contains the frame number on which tokens[i][k]
     # is decoded
     timestamps: List[List[int]]
 
-    # hyps[i] is the recognition results, i.e., word IDs
+    # hyps[i] is the recognition results, i.e., word IDs or token IDs
     # for the i-th utterance with fast_beam_search_nbest_LG.
-    hyps: Union[List[List[int]], k2.RaggedTensor] = None
-
-
-def get_tokens_and_timestamps(labels: List[int]) -> Tuple[List[int], List[int]]:
-    tokens = []
-    timestamps = []
-    for i, v in enumerate(labels):
-        if v != 0:
-            tokens.append(v)
-            timestamps.append(i)
-
-    return tokens, timestamps
+    hyps: Union[List[List[int]], k2.RaggedTensor]
 
 
 def get_texts_with_timestamp(
     best_paths: k2.Fsa, return_ragged: bool = False
 ) -> DecodingResults:
-    """Extract the texts (as word IDs) and timestamps from the best-path FSAs.
+    """Extract the texts (as word IDs) and timestamps (as frame indexes)
+    from the best-path FSAs.
     Args:
       best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
@@ -292,11 +279,18 @@ def get_texts_with_timestamp(
       decoded.
     """
     if isinstance(best_paths.aux_labels, k2.RaggedTensor):
+        all_aux_shape = (
+            best_paths.arcs.shape()
+            .remove_axis(1)
+            .compose(best_paths.aux_labels.shape)
+        )
+        all_aux_labels = k2.RaggedTensor(
+            all_aux_shape, best_paths.aux_labels.values
+        )
         # remove 0's and -1's.
         aux_labels = best_paths.aux_labels.remove_values_leq(0)
         # TODO: change arcs.shape() to arcs.shape
         aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
-
         # remove the states and arcs axes.
         aux_shape = aux_shape.remove_axis(1)
         aux_shape = aux_shape.remove_axis(1)
@@ -304,26 +298,26 @@ def get_texts_with_timestamp(
     else:
         # remove axis corresponding to states.
         aux_shape = best_paths.arcs.shape().remove_axis(1)
-        aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
+        all_aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
         # remove 0's and -1's.
-        aux_labels = aux_labels.remove_values_leq(0)
+        aux_labels = all_aux_labels.remove_values_leq(0)
 
     assert aux_labels.num_axes == 2
 
-    labels_shape = best_paths.arcs.shape().remove_axis(1)
-    labels_list = k2.RaggedTensor(
-        labels_shape, best_paths.labels.contiguous()
-    ).tolist()
-
-    tokens = []
     timestamps = []
-    for labels in labels_list:
-        token, time = get_tokens_and_timestamps(labels[:-1])
-        tokens.append(token)
-        timestamps.append(time)
+    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
+        for p in range(all_aux_labels.dim0):
+            time = []
+            for i, arc in enumerate(all_aux_labels[p].tolist()):
+                if len(arc) == 1 and arc[0] > 0:
+                    time.append(i)
+            timestamps.append(time)
+    else:
+        for labels in all_aux_labels.tolist():
+            time = [i for i, v in enumerate(labels) if v > 0]
+            timestamps.append(time)
 
     return DecodingResults(
-        tokens=tokens,
         timestamps=timestamps,
         hyps=aux_labels if return_ragged else aux_labels.tolist(),
     )
@@ -1399,8 +1393,8 @@ def parse_hyp_and_timestamp(
     hyps = []
     timestamps = []
-    N = len(res.tokens)
-    assert len(res.timestamps) == N
+    N = len(res.hyps)
+    assert len(res.timestamps) == N, (len(res.timestamps), N)
 
     use_word_table = False
     if (
         decoding_method == "fast_beam_search_nbest_LG"
@@ -1410,16 +1404,16 @@ def parse_hyp_and_timestamp(
         use_word_table = True
 
     for i in range(N):
-        tokens = sp.id_to_piece(res.tokens[i])
-        if use_word_table:
-            words = [word_table[i] for i in res.hyps[i]]
-        else:
-            words = sp.decode_pieces(tokens).split()
         time = convert_timestamp(
             res.timestamps[i], subsampling_factor, frame_shift_ms
         )
-        time = parse_timestamp(tokens, time)
-        assert len(time) == len(words), (tokens, words)
+        if use_word_table:
+            words = [word_table[i] for i in res.hyps[i]]
+        else:
+            tokens = sp.id_to_piece(res.hyps[i])
+            words = sp.decode_pieces(tokens).split()
+            time = parse_timestamp(tokens, time)
+        assert len(time) == len(words), (len(time), len(words))
         hyps.append(words)
         timestamps.append(time)
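
A quick sketch of the idea behind this refactor: along a linear best path produced by k2.shortest_path, there is one aux label per arc and (for the frame-synchronous decoding used here) one arc per frame, so the timestamp of the k-th emitted symbol is simply the index of the arc that carries it. That is why the removed get_tokens_and_timestamps() helper, which scanned best_paths.labels, is no longer needed. The plain-Python code below mirrors the non-ragged branch of the new get_texts_with_timestamp(); it is illustrative only, with made-up names (texts_with_timestamps, all_aux_labels), and real icefall code operates on k2 ragged tensors rather than Python lists.

# Minimal, runnable sketch (no k2) of the non-ragged timestamp extraction.
from dataclasses import dataclass
from typing import List


@dataclass
class DecodingResults:
    # timestamps[i][k] is the frame index on which hyps[i][k] is decoded.
    timestamps: List[List[int]]
    # hyps[i] holds the word IDs (or token IDs) of the i-th utterance.
    hyps: List[List[int]]


def texts_with_timestamps(all_aux_labels: List[List[int]]) -> DecodingResults:
    """all_aux_labels[i][j]: aux label on the j-th arc of the i-th best path.

    0 plays the role of epsilon/blank and -1 marks the final arc; both are
    dropped from the hypotheses, analogous to remove_values_leq(0).
    """
    hyps = [[v for v in labels if v > 0] for labels in all_aux_labels]
    timestamps = [
        [i for i, v in enumerate(labels) if v > 0] for labels in all_aux_labels
    ]
    return DecodingResults(timestamps=timestamps, hyps=hyps)


# Two toy "best paths": symbols 7 and 9 are emitted on frames 1 and 4 of the
# first path; symbols 3 and 5 on frames 0 and 2 of the second.
res = texts_with_timestamps([[0, 7, 0, 0, 9, -1], [3, 0, 5, -1]])
assert res.hyps == [[7, 9], [3, 5]]
assert res.timestamps == [[1, 4], [0, 2]]

For the ragged case (e.g., word IDs from fast_beam_search_nbest_LG, where a single arc may carry zero or several aux symbols), the patch instead walks the per-arc sublists and records frame i only when the sublist contains exactly one positive ID, which is what the len(arc) == 1 and arc[0] > 0 test in get_texts_with_timestamp() does.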