def parse_timestamps_and_texts(
    best_paths: k2.Fsa, word_table: k2.SymbolTable
) -> Tuple[List[Tuple[int, int]], List[List[str]]]:
    """Extract per-word (start, end) index pairs and the word strings.

    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. holding
        several FSAs. It is expected to come from k2.shortest_path;
        otherwise the returned values won't be meaningful.
        Attribute `labels` holds the prediction units (e.g., phones or
        BPE tokens). Attribute `aux_labels` holds the word indexes and is
        read through its `.shape` and `.values`.
      word_table:
        The word symbol table, used to map word indexes to strings.

    Returns:
      utt_index_pairs:
        A list with one entry per utterance. utt_index_pairs[i] is itself
        a list of (start-frame-index, end-frame-index) pairs, one pair per
        word of utterance-i. The indexes are positions in the utterance's
        label sequence after the -1 labels have been removed.
      utt_words:
        A list with one entry per utterance. utt_words[i] is the list of
        words of utterance-i.

    Raises:
      AssertionError:
        If an utterance does not have exactly one more aux_labels entry
        than labels, or if the number of index pairs found differs from
        the number of words returned by get_texts.
    """
    # [utt][words]
    word_ids = get_texts(best_paths)

    # Drop the state axis so that arcs are grouped directly per utterance.
    utt_arc_shape = best_paths.arcs.shape().remove_axis(1)

    # [utt][arcs], with the -1 labels removed.
    token_lists = (
        k2.RaggedTensor(utt_arc_shape, best_paths.labels.contiguous())
        .remove_values_eq(-1)
        .tolist()
    )

    # [utt][arcs][word ids attached to that arc]
    aux_lists = k2.RaggedTensor(
        utt_arc_shape.compose(best_paths.aux_labels.shape),
        best_paths.aux_labels.values.contiguous(),
    ).tolist()

    utt_index_pairs = []
    utt_words = []
    for utt_idx, (tokens, arc_words) in enumerate(zip(token_lists, aux_lists)):
        n_tokens = len(tokens)
        # arc_words has one extra entry: the arc entering the final state.
        assert n_tokens == len(arc_words) - 1, (n_tokens, len(arc_words))

        spans = []
        first = -1
        last = -1
        for pos, token in enumerate(tokens):
            if token == 0:
                # Label 0 neither opens a word nor extends the current one.
                continue
            # arc_words[pos] holds 0 or 1 word ids.
            if arc_words[pos]:
                # A new word begins here; flush the one in progress, if any.
                if first != -1 and last != -1:
                    spans.append((first, last))
                first = pos
            last = pos
        # Flush the trailing word.
        if first != -1 and last != -1:
            spans.append((first, last))

        words = [word_table[word_id] for word_id in word_ids[utt_idx]]
        assert len(spans) == len(words), (len(spans), len(words))

        utt_index_pairs.append(spans)
        utt_words.append(words)

    return utt_index_pairs, utt_words