From 1a6a035d87323de2d83987a2164a3a90f162bbaa Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Sun, 5 Feb 2023 18:20:49 +0800
Subject: [PATCH] parse timestamps and texts for BPE-based models

---
 icefall/utils.py | 126 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 126 insertions(+)

diff --git a/icefall/utils.py b/icefall/utils.py
index ba0b7fe43..0ea908cbc 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -1431,3 +1431,129 @@ def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
         batch["supervisions"][k] = v[:keep_num_utt]
 
     return batch
+
+
+def parse_bpe_start_end_pairs(
+    tokens: List[str], is_first_token: List[bool]
+) -> List[Tuple[int, int]]:
+    """Parse pairs of start and end frame indexes for each word.
+
+    Args:
+      tokens:
+        List of BPE tokens.
+      is_first_token:
+        List of bool values, which indicates whether it is the first token,
+        i.e., not a repeat and not a blank.
+
+    Returns:
+      List of (start-frame-index, end-frame-index) pairs for each word.
+    """
+    assert len(tokens) == len(is_first_token), (len(tokens), len(is_first_token))
+
+    start_token = b"\xe2\x96\x81".decode()  # '▁'
+    blank_token = "<blk>"
+
+    non_blank_idx = [i for i in range(len(tokens)) if tokens[i] != blank_token]
+    num_non_blank = len(non_blank_idx)
+
+    pairs = []
+    start = -1
+    end = -1
+    for j in range(num_non_blank):
+        # The index in all frames
+        i = non_blank_idx[j]
+
+        found_start = False
+        if is_first_token[i] and (j == 0 or tokens[i].startswith(start_token)):
+            found_start = True
+            if tokens[i] == start_token:
+                if j == num_non_blank - 1:
+                    # It is the last non-blank token
+                    found_start = False
+                elif is_first_token[non_blank_idx[j + 1]] and tokens[
+                    non_blank_idx[j + 1]
+                ].startswith(start_token):
+                    # The next non-blank token is a first-token and also starts with start_token
+                    found_start = False
+        if found_start:
+            start = i
+
+        if start != -1:
+            found_end = False
+            if j == num_non_blank - 1:
+                # It is the last non-blank token
+                found_end = True
+            elif is_first_token[non_blank_idx[j + 1]] and tokens[
+                non_blank_idx[j + 1]
+            ].startswith(start_token):
+                # The next non-blank token is a first-token and also starts with start_token
+                found_end = True
+            if found_end:
+                end = i
+
+        if start != -1 and end != -1:
+            pairs.append((start, end))
+            # Reset start and end
+            start = -1
+            end = -1
+
+    return pairs
+
+
+def parse_bpe_timestamps_and_texts(
+    best_paths: k2.Fsa, sp: spm.SentencePieceProcessor
+) -> Tuple[List[Tuple[int, int]], List[List[str]]]:
+    """Parse timestamps (frame indexes) and texts for BPE-based models.
+
+    Args:
+      best_paths:
+        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
+        containing multiple FSAs, which is expected to be the result
+        of k2.shortest_path (otherwise the returned values won't
+        be meaningful).
+      sp:
+        The BPE model.
+
+    Returns:
+      utt_index_pairs:
+        A list of pair lists. utt_index_pairs[i] is a list of
+        (start-frame-index, end-frame-index) pairs for each word in
+        utterance-i.
+      utt_words:
+        A list of str lists. utt_words[i] is the word list of utterance-i.
+    """
+    shape = best_paths.arcs.shape().remove_axis(1)
+
+    # labels: [utt][arcs]
+    labels = k2.RaggedTensor(shape, best_paths.labels.contiguous())
+    # remove -1's.
+    labels = labels.remove_values_eq(-1)
+    labels = labels.tolist()
+
+    # aux_labels: [utt][arcs]
+    aux_labels = k2.RaggedTensor(shape, best_paths.aux_labels.contiguous())
+
+    # remove -1's.
+    all_aux_labels = aux_labels.remove_values_eq(-1)
+    # len(all_aux_labels[i]) is equal to the number of frames
+    all_aux_labels = all_aux_labels.tolist()
+
+    # remove 0's and -1's.
+    out_aux_labels = aux_labels.remove_values_leq(0)
+    # len(out_aux_labels[i]) is equal to the number of output BPE tokens
+    out_aux_labels = out_aux_labels.tolist()
+
+    utt_index_pairs = []
+    utt_words = []
+    for i in range(len(labels)):
+        tokens = sp.id_to_piece(labels[i])
+        words = sp.decode(out_aux_labels[i]).split()
+
+        # Indicates whether it is the first token, i.e., not a repeat and not a blank.
+        is_first_token = [a != 0 for a in all_aux_labels[i]]
+        index_pairs = parse_bpe_start_end_pairs(tokens, is_first_token)
+        assert len(index_pairs) == len(words), (len(index_pairs), len(words))
+        utt_index_pairs.append(index_pairs)
+        utt_words.append(words)
+
+    return utt_index_pairs, utt_words
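
A note for reviewers, not part of the patch: below is a minimal sketch of how parse_bpe_start_end_pairs behaves on a hand-made CTC best path. The token strings, the frame layout, and the 10 ms frame shift with 4x subsampling used in the seconds conversion are illustrative assumptions, not values taken from the patch.

    from icefall.utils import parse_bpe_start_end_pairs

    # Frame-level tokens along a best path for "HELLO WORLD", segmented
    # as "▁HE" "LLO" "▁WORLD" ("▁" marks a word start in SentencePiece).
    tokens = ["<blk>", "▁HE", "▁HE", "LLO", "<blk>", "▁WORLD", "▁WORLD", "<blk>"]
    # aux_labels on the best path are nonzero only at the first frame of
    # each emitted token, so repeated frames and blanks map to False.
    is_first_token = [False, True, False, True, False, True, False, False]

    pairs = parse_bpe_start_end_pairs(tokens, is_first_token)
    print(pairs)  # [(1, 3), (5, 6)]: "HELLO" spans frames 1-3, "WORLD" frames 5-6

    # Convert frame indexes to seconds, assuming 10 ms features with a
    # 4x subsampling factor (model-dependent; adjust for your recipe).
    frame_shift = 0.01 * 4
    print([(s * frame_shift, e * frame_shift) for s, e in pairs])

The returned pairs index the (subsampled) frame axis of the best path, so the output of parse_bpe_timestamps_and_texts applied to a k2.shortest_path result can be scaled to seconds the same way.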