mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 11:32:19 +00:00
parse timestamps (frame indexes) and texts for other cases
This commit is contained in:
parent
1a6a035d87
commit
00e6a8dd78
@ -1510,7 +1510,8 @@ def parse_bpe_timestamps_and_texts(
|
||||
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
|
||||
containing multiple FSAs, which is expected to be the result
|
||||
of k2.shortest_path (otherwise the returned values won't
|
||||
be meaningful).
|
||||
be meaningful). Its attribtutes `labels` and `aux_labels`
|
||||
are both BPE tokens.
|
||||
sp:
|
||||
The BPE model.
|
||||
|
||||
@ -1557,3 +1558,72 @@ def parse_bpe_timestamps_and_texts(
|
||||
utt_words.append(words)
|
||||
|
||||
return utt_index_pairs, utt_words
|
||||
|
||||
|
||||
def parse_timestamps_and_texts(
|
||||
best_paths: k2.Fsa, word_table: k2.SymbolTable
|
||||
) -> Tuple[List[Tuple[int, int]], List[List[str]]]:
|
||||
"""Parse timestamps (frame indexes) and texts.
|
||||
|
||||
Args:
|
||||
best_paths:
|
||||
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
|
||||
containing multiple FSAs, which is expected to be the result
|
||||
of k2.shortest_path (otherwise the returned values won't
|
||||
be meaningful). Attribtute `labels` is the prediction unit,
|
||||
e.g., phone or BPE tokens. Attribute `aux_labels` is the word index.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
|
||||
Returns:
|
||||
utt_index_pairs:
|
||||
A list of pair list. utt_index_pairs[i] is a list of
|
||||
(start-frame-index, end-frame-index) pairs for each word in
|
||||
utterance-i.
|
||||
utt_words:
|
||||
A list of str list. utt_words[i] is a word list of utterence-i.
|
||||
"""
|
||||
# [utt][words]
|
||||
word_ids = get_texts(best_paths)
|
||||
|
||||
shape = best_paths.arcs.shape().remove_axis(1)
|
||||
|
||||
# labels: [utt][arcs]
|
||||
labels = k2.RaggedTensor(shape, best_paths.labels.contiguous())
|
||||
# remove -1's.
|
||||
labels = labels.remove_values_eq(-1)
|
||||
labels = labels.tolist()
|
||||
|
||||
# aux_labels: [utt][arcs]
|
||||
aux_shape = shape.compose(best_paths.aux_labels.shape)
|
||||
aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels.values.contiguous())
|
||||
aux_labels = aux_labels.tolist()
|
||||
|
||||
utt_index_pairs = []
|
||||
utt_words = []
|
||||
for i, (label, aux_label) in enumerate(zip(labels, aux_labels)):
|
||||
num_arcs = len(label)
|
||||
# The last arc of aux_label is the arc entering the final state
|
||||
assert num_arcs == len(aux_label) - 1, (num_arcs, len(aux_label))
|
||||
|
||||
index_pairs = []
|
||||
start = -1
|
||||
end = -1
|
||||
for arc in range(num_arcs):
|
||||
# len(aux_label[arc]) is 0 or 1
|
||||
if label[arc] != 0 and len(aux_label[arc]) != 0:
|
||||
if start != -1 and end != -1:
|
||||
index_pairs.append((start, end))
|
||||
start = arc
|
||||
if label[arc] != 0:
|
||||
end = arc
|
||||
if start != -1 and end != -1:
|
||||
index_pairs.append((start, end))
|
||||
|
||||
words = [word_table[w] for w in word_ids[i]]
|
||||
assert len(index_pairs) == len(words), (len(index_pairs), len(words))
|
||||
|
||||
utt_index_pairs.append(index_pairs)
|
||||
utt_words.append(words)
|
||||
|
||||
return utt_index_pairs, utt_words
|
||||
|
Loading…
x
Reference in New Issue
Block a user