parse timestamps (frame indexes) and texts for other cases

This commit is contained in:
yaozengwei 2023-02-06 17:09:52 +08:00
parent 1a6a035d87
commit 00e6a8dd78

View File

@ -1510,7 +1510,8 @@ def parse_bpe_timestamps_and_texts(
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
containing multiple FSAs, which is expected to be the result
of k2.shortest_path (otherwise the returned values won't
be meaningful).
be meaningful). Its attribtutes `labels` and `aux_labels`
are both BPE tokens.
sp:
The BPE model.
@ -1557,3 +1558,72 @@ def parse_bpe_timestamps_and_texts(
utt_words.append(words)
return utt_index_pairs, utt_words
def parse_timestamps_and_texts(
best_paths: k2.Fsa, word_table: k2.SymbolTable
) -> Tuple[List[Tuple[int, int]], List[List[str]]]:
"""Parse timestamps (frame indexes) and texts.
Args:
best_paths:
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
containing multiple FSAs, which is expected to be the result
of k2.shortest_path (otherwise the returned values won't
be meaningful). Attribtute `labels` is the prediction unit,
e.g., phone or BPE tokens. Attribute `aux_labels` is the word index.
word_table:
The word symbol table.
Returns:
utt_index_pairs:
A list of pair list. utt_index_pairs[i] is a list of
(start-frame-index, end-frame-index) pairs for each word in
utterance-i.
utt_words:
A list of str list. utt_words[i] is a word list of utterence-i.
"""
# [utt][words]
word_ids = get_texts(best_paths)
shape = best_paths.arcs.shape().remove_axis(1)
# labels: [utt][arcs]
labels = k2.RaggedTensor(shape, best_paths.labels.contiguous())
# remove -1's.
labels = labels.remove_values_eq(-1)
labels = labels.tolist()
# aux_labels: [utt][arcs]
aux_shape = shape.compose(best_paths.aux_labels.shape)
aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels.values.contiguous())
aux_labels = aux_labels.tolist()
utt_index_pairs = []
utt_words = []
for i, (label, aux_label) in enumerate(zip(labels, aux_labels)):
num_arcs = len(label)
# The last arc of aux_label is the arc entering the final state
assert num_arcs == len(aux_label) - 1, (num_arcs, len(aux_label))
index_pairs = []
start = -1
end = -1
for arc in range(num_arcs):
# len(aux_label[arc]) is 0 or 1
if label[arc] != 0 and len(aux_label[arc]) != 0:
if start != -1 and end != -1:
index_pairs.append((start, end))
start = arc
if label[arc] != 0:
end = arc
if start != -1 and end != -1:
index_pairs.append((start, end))
words = [word_table[w] for w in word_ids[i]]
assert len(index_pairs) == len(words), (len(index_pairs), len(words))
utt_index_pairs.append(index_pairs)
utt_words.append(words)
return utt_index_pairs, utt_words