From 3e1d14b9f8f597f2a336fda2c06c55d5c9a8beb3 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Tue, 7 Feb 2023 14:59:29 +0800 Subject: [PATCH] add parse_fsa_timestamps_and_texts function, test in conformer_ctc3/decode.py --- egs/librispeech/ASR/conformer_ctc3/decode.py | 44 +++++--- icefall/utils.py | 102 ++++++++++++++++++- 2 files changed, 125 insertions(+), 21 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py index 39186e546..33d04650f 100755 --- a/egs/librispeech/ASR/conformer_ctc3/decode.py +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -96,8 +96,7 @@ from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, get_texts, - get_texts_with_timestamp, - parse_hyp_and_timestamp, + parse_fsa_timestamps_and_texts, setup_logger, store_transcripts_and_timestamps, str2bool, @@ -396,13 +395,8 @@ def decode_one_batch( best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - res = get_texts_with_timestamp(best_path) - hyps, timestamps = parse_hyp_and_timestamp( - res=res, + timestamps, hyps = parse_fsa_timestamps_and_texts( + best_paths=best_path, sp=bpe_model, subsampling_factor=params.subsampling_factor, frame_shift_ms=params.frame_shift_ms, @@ -435,12 +429,11 @@ def decode_one_batch( lattice=lattice, use_double_scores=params.use_double_scores ) key = f"no_rescore_hlg_scale_{params.hlg_scale}" - res = get_texts_with_timestamp(best_path) - hyps, timestamps = parse_hyp_and_timestamp( - res=res, + timestamps, hyps = parse_fsa_timestamps_and_texts( + best_paths=best_path, + word_table=word_table, subsampling_factor=params.subsampling_factor, frame_shift_ms=params.frame_shift_ms, - word_table=word_table, ) else: best_path = nbest_decoding( @@ -504,7 +497,18 @@ def decode_dataset( sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: +) -> Dict[ + str, + List[ + Tuple[ + str, + List[str], + List[str], + List[Tuple[float, float]], + List[Tuple[float, float]], + ] + ], +]: """Decode dataset. Args: @@ -555,7 +559,7 @@ def decode_dataset( time = [] if s.alignment is not None and "word" in s.alignment: time = [ - aliword.start + (aliword.start, aliword.end) for aliword in s.alignment["word"] if aliword.symbol != "" ] @@ -601,7 +605,15 @@ def save_results( test_set_name: str, results_dict: Dict[ str, - List[Tuple[List[str], List[str], List[str], List[float], List[float]]], + List[ + Tuple[ + List[str], + List[str], + List[str], + List[Tuple[float, float]], + List[Tuple[float, float]], + ] + ], ], ): test_set_wers = dict() diff --git a/icefall/utils.py b/icefall/utils.py index 8ee62949d..c89356c58 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -454,11 +454,32 @@ def store_transcripts_and_timestamps( for cut_id, ref, hyp, time_ref, time_hyp in texts: print(f"{cut_id}:\tref={ref}", file=f) print(f"{cut_id}:\thyp={hyp}", file=f) + if len(time_ref) > 0: - s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]" + if isinstance(time_ref[0], tuple): + # each element is pair + s = ( + "[" + + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_ref]) + + "]" + ) + else: + # each element is a float number + s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]" print(f"{cut_id}:\ttimestamp_ref={s}", file=f) - s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]" - print(f"{cut_id}:\ttimestamp_hyp={s}", file=f) + + if len(time_hyp) > 0: + if isinstance(time_hyp[0], tuple): + # each element is pair + s = ( + "[" + + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_hyp]) + + "]" + ) + else: + # each element is a float number + s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]" + print(f"{cut_id}:\ttimestamp_hyp={s}", file=f) def write_error_stats( @@ -1493,7 +1514,9 @@ def parse_bpe_start_end_pairs( end = i if start != -1 and end != -1: - pairs.append((start, end)) + if not all([tokens[t] == start_token for t in range(start, end + 1)]): + # except the case of all start_token + pairs.append((start, end)) # Reset start and end start = -1 end = -1 @@ -1554,7 +1577,7 @@ def parse_bpe_timestamps_and_texts( # Indicates whether it is the first token, i.e., not-repeat and not-blank. is_first_token = [a != 0 for a in all_aux_labels[i]] index_pairs = parse_bpe_start_end_pairs(tokens, is_first_token) - assert len(index_pairs) == len(words), (len(index_pairs), len(words)) + assert len(index_pairs) == len(words), (len(index_pairs), len(words), tokens) utt_index_pairs.append(index_pairs) utt_words.append(words) @@ -1628,3 +1651,72 @@ def parse_timestamps_and_texts( utt_words.append(words) return utt_index_pairs, utt_words + + +def parse_fsa_timestamps_and_texts( + best_paths: k2.Fsa, + sp: Optional[spm.SentencePieceProcessor] = None, + word_table: Optional[k2.SymbolTable] = None, + subsampling_factor: int = 4, + frame_shift_ms: float = 10, +) -> Tuple[List[Tuple[float, float]], List[List[str]]]: + """Parse timestamps (in seconds) and texts for given decoded fsa paths. + Currently it supports two case: + (1) ctc-decoding, the attribtutes `labels` and `aux_labels` + are both BPE tokens. In this case, sp should be provided. + (2) HLG-based 1best, the attribtute `labels` is the prediction unit, + e.g., phone or BPE tokens; attribute `aux_labels` is the word index. + In this case, word_table should be provided. + + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). + sp: + The BPE model. + word_table: + The word symbol table. + subsampling_factor: + The subsampling factor of the model. + frame_shift_ms: + Frame shift in milliseconds between two contiguous frames. + + Returns: + utt_time_pairs: + A list of pair list. utt_time_pairs[i] is a list of + (start-time, end-time) pairs for each word in + utterance-i. + utt_words: + A list of str list. utt_words[i] is a word list of utterence-i. + """ + if sp is not None: + assert word_table is None, "word_table is not needed if sp is provided." + utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts( + best_paths=best_paths, sp=sp + ) + elif word_table is not None: + assert sp is None, "sp is not needed if word_table is provided." + utt_index_pairs, utt_words = parse_timestamps_and_texts( + best_paths=best_paths, word_table=word_table + ) + else: + raise ValueError("Either sp or word_table should be provided.") + + utt_time_pairs = [] + for utt in utt_index_pairs: + start = convert_timestamp( + frames=[i[0] for i in utt], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + end = convert_timestamp( + # The duration in frames is (end_frame_index - start_frame_index + 1) + frames=[i[1] + 1 for i in utt], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + utt_time_pairs.append(list(zip(start, end))) + + return utt_time_pairs, utt_words