mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-14 04:22:21 +00:00
parse timestamps and texts for BPE-based models
This commit is contained in:
parent
dd0047e605
commit
1a6a035d87
126
icefall/utils.py
126
icefall/utils.py
@ -1431,3 +1431,129 @@ def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
|
|||||||
batch["supervisions"][k] = v[:keep_num_utt]
|
batch["supervisions"][k] = v[:keep_num_utt]
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def parse_bpe_start_end_pairs(
    tokens: List[str], is_first_token: List[bool]
) -> List[Tuple[int, int]]:
    """Parse pairs of start and end frame indexes for each word.

    A word starts at a newly emitted token that is either the first non-blank
    frame or begins with the word-boundary marker (U+2581), and it ends at the
    last non-blank frame before the next word starts (or at the last non-blank
    frame of the utterance). A bare word-boundary marker that is not followed
    by a continuation token does not form a word and is skipped, even when it
    is repeated over several frames.

    Args:
      tokens:
        List of BPE tokens, one per frame.
      is_first_token:
        List of bool values, which indicates whether it is the first token,
        i.e., not repeat or blank. Must have the same length as ``tokens``.

    Returns:
      List of (start-frame-index, end-frame-index) pairs for each word.
      Both indexes are inclusive and refer to positions in ``tokens``.
    """
    assert len(tokens) == len(is_first_token), (len(tokens), len(is_first_token))

    start_token = b"\xe2\x96\x81".decode()  # U+2581, the word-boundary marker
    blank_token = "<blk>"

    non_blank_idx = [i for i, token in enumerate(tokens) if token != blank_token]
    num_non_blank = len(non_blank_idx)

    # next_emitted[j] is the frame index of the next newly emitted (first)
    # token after non-blank position j, or -1 if there is none. Looking at
    # the next emitted token rather than the next non-blank frame is what
    # lets us skip over repeats of a bare word-boundary marker.
    next_emitted = [-1] * num_non_blank
    upcoming = -1
    for j in range(num_non_blank - 1, -1, -1):
        next_emitted[j] = upcoming
        if is_first_token[non_blank_idx[j]]:
            upcoming = non_blank_idx[j]

    pairs = []
    start = -1
    for j, i in enumerate(non_blank_idx):
        # i is the index in all frames, j the index among non-blank frames.
        if is_first_token[i] and (j == 0 or tokens[i].startswith(start_token)):
            found_start = True
            if tokens[i] == start_token:
                following = next_emitted[j]
                # A bare marker only opens a word if the next emitted token
                # continues it; otherwise it decodes to nothing.
                if following == -1 or tokens[following].startswith(start_token):
                    found_start = False
            if found_start:
                start = i

        if start == -1:
            # Not inside a word, so there is nothing to close.
            continue

        if j == num_non_blank - 1:
            # It is the last non-blank token.
            found_end = True
        else:
            # The word ends here if the next non-blank token is a first-token
            # that also starts with start_token.
            nxt = non_blank_idx[j + 1]
            found_end = is_first_token[nxt] and tokens[nxt].startswith(start_token)

        if found_end:
            pairs.append((start, i))
            # Reset so that the next word has to be opened explicitly.
            start = -1

    return pairs
|
||||||
|
|
||||||
|
|
||||||
|
def parse_bpe_timestamps_and_texts(
    best_paths: k2.Fsa, sp: spm.SentencePieceProcessor
) -> Tuple[List[List[Tuple[int, int]]], List[List[str]]]:
    """Parse timestamps (frame indexes) and texts for BPE-based models.

    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
      sp:
        The BPE model.

    Returns:
      utt_index_pairs:
        A list of pair list. utt_index_pairs[i] is a list of
        (start-frame-index, end-frame-index) pairs for each word in
        utterance-i.
      utt_words:
        A list of str list. utt_words[i] is a word list of utterance-i.

    Raises:
      AssertionError:
        If the number of parsed (start, end) pairs of an utterance differs
        from the number of words decoded for it.
    """
    # Drop the state axis so that both ragged tensors are indexed [utt][arcs].
    shape = best_paths.arcs.shape().remove_axis(1)

    # labels: [utt][arcs]
    labels = k2.RaggedTensor(shape, best_paths.labels.contiguous())
    # remove -1's.
    labels = labels.remove_values_eq(-1)
    labels = labels.tolist()

    # aux_labels: [utt][arcs]
    aux_labels = k2.RaggedTensor(shape, best_paths.aux_labels.contiguous())

    # remove -1's.
    all_aux_labels = aux_labels.remove_values_eq(-1)
    # len(all_aux_labels[i]) is equal to the number of frames
    all_aux_labels = all_aux_labels.tolist()

    # remove 0's and -1's.
    out_aux_labels = aux_labels.remove_values_leq(0)
    # len(out_aux_labels[i]) is equal to the number of output BPE tokens
    out_aux_labels = out_aux_labels.tolist()

    utt_index_pairs = []
    utt_words = []
    for utt_labels, utt_all_aux, utt_out_aux in zip(
        labels, all_aux_labels, out_aux_labels
    ):
        tokens = sp.id_to_piece(utt_labels)
        words = sp.decode(utt_out_aux).split()

        # Indicates whether it is the first token, i.e., not-repeat and not-blank.
        is_first_token = [a != 0 for a in utt_all_aux]
        index_pairs = parse_bpe_start_end_pairs(tokens, is_first_token)
        # Every decoded word must have exactly one (start, end) pair.
        assert len(index_pairs) == len(words), (len(index_pairs), len(words))
        utt_index_pairs.append(index_pairs)
        utt_words.append(words)

    return utt_index_pairs, utt_words
|
||||||
|
Loading…
x
Reference in New Issue
Block a user