Merge branch 'k2-fsa:master' into repeat-k

Yifan Yang committed on 2023-02-08 10:51:51 +08:00 via GitHub
commit e3beb93e1d
3 changed files with 545 additions and 39 deletions


@@ -96,8 +96,7 @@ from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     get_texts,
-    get_texts_with_timestamp,
-    parse_hyp_and_timestamp,
+    parse_fsa_timestamps_and_texts,
     setup_logger,
     store_transcripts_and_timestamps,
     str2bool,
@@ -396,13 +395,8 @@ def decode_one_batch(
         best_path = one_best_decoding(
             lattice=lattice, use_double_scores=params.use_double_scores
         )
-        # Note: `best_path.aux_labels` contains token IDs, not word IDs
-        # since we are using H, not HLG here.
-        #
-        # token_ids is a lit-of-list of IDs
-        res = get_texts_with_timestamp(best_path)
-        hyps, timestamps = parse_hyp_and_timestamp(
-            res=res,
+        timestamps, hyps = parse_fsa_timestamps_and_texts(
+            best_paths=best_path,
             sp=bpe_model,
             subsampling_factor=params.subsampling_factor,
             frame_shift_ms=params.frame_shift_ms,
@@ -435,12 +429,11 @@ def decode_one_batch(
             lattice=lattice, use_double_scores=params.use_double_scores
         )
         key = f"no_rescore_hlg_scale_{params.hlg_scale}"
-        res = get_texts_with_timestamp(best_path)
-        hyps, timestamps = parse_hyp_and_timestamp(
-            res=res,
+        timestamps, hyps = parse_fsa_timestamps_and_texts(
+            best_paths=best_path,
+            word_table=word_table,
             subsampling_factor=params.subsampling_factor,
             frame_shift_ms=params.frame_shift_ms,
-            word_table=word_table,
         )
     else:
         best_path = nbest_decoding(
@@ -504,7 +497,18 @@ def decode_dataset(
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
+) -> Dict[
+    str,
+    List[
+        Tuple[
+            str,
+            List[str],
+            List[str],
+            List[Tuple[float, float]],
+            List[Tuple[float, float]],
+        ]
+    ],
+]:
     """Decode dataset.

     Args:
@@ -555,7 +559,7 @@ def decode_dataset(
                 time = []
                 if s.alignment is not None and "word" in s.alignment:
                     time = [
-                        aliword.start
+                        (aliword.start, aliword.end)
                         for aliword in s.alignment["word"]
                         if aliword.symbol != ""
                     ]
@@ -601,7 +605,15 @@ def save_results(
     test_set_name: str,
     results_dict: Dict[
         str,
-        List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
+        List[
+            Tuple[
+                List[str],
+                List[str],
+                List[str],
+                List[Tuple[float, float]],
+                List[Tuple[float, float]],
+            ]
+        ],
     ],
 ):
     test_set_wers = dict()
@@ -621,7 +633,11 @@ def save_results(
         )
         with open(errs_filename, "w") as f:
             wer, mean_delay, var_delay = write_error_stats_with_timestamps(
-                f, f"{test_set_name}-{key}", results, enable_log=True
+                f,
+                f"{test_set_name}-{key}",
+                results,
+                enable_log=True,
+                with_end_time=True,
             )
             test_set_wers[key] = wer
             test_set_delays[key] = (mean_delay, var_delay)
@@ -637,16 +653,17 @@ def save_results(
         for key, val in test_set_wers:
             print("{}\t{}".format(key, val), file=f)

-    test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
+    # sort according to the mean start symbol delay
+    test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0])
     delays_info = (
         params.res_dir
         / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(delays_info, "w") as f:
-        print("settings\tsymbol-delay", file=f)
+        print("settings\t(start, end) symbol-delay (s) (start, end)", file=f)
         for key, val in test_set_delays:
             print(
-                "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
+                "{}\tmean: {}, variance: {}".format(key, val[0], val[1]),
                 file=f,
             )
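For illustration, with the (start, end) pairs above, a line in the symbol-delay summary file would look roughly like this (made-up key and values, already rounded to three decimals):

    ctc-decoding    mean: (0.02, 0.06), variance: (0.001, 0.002)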
@@ -657,10 +674,12 @@ def save_results(
         note = ""
     logging.info(s)

-    s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name)
+    s = "\nFor {}, (start, end) symbol-delay (s) of different settings are:\n".format(
+        test_set_name
+    )
     note = "\tbest for {}".format(test_set_name)
     for key, val in test_set_delays:
-        s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
+        s += "{}\tmean: {}, variance: {}{}\n".format(key, val[0], val[1], note)
         note = ""

     logging.info(s)
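For orientation, parse_fsa_timestamps_and_texts returns per-word (start, end) times in seconds rather than single start times, which is why the reference times collected from the alignments above are now (aliword.start, aliword.end) pairs as well. A minimal sketch of the new shapes, with made-up values:

    # One utterance with two hypothesis words; values are illustrative only.
    timestamps = [[(0.52, 0.96), (1.04, 1.60)]]
    hyps = [["HELLO", "WORLD"]]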


@@ -1,5 +1,6 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
-#                                        Mingshuang Luo)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+#                                        Mingshuang Luo,
+#                                        Zengwei Yao)
 #
 # See ../../LICENSE for clarification regarding multiple authors
 #
@@ -453,9 +454,30 @@ def store_transcripts_and_timestamps(
         for cut_id, ref, hyp, time_ref, time_hyp in texts:
             print(f"{cut_id}:\tref={ref}", file=f)
             print(f"{cut_id}:\thyp={hyp}", file=f)
             if len(time_ref) > 0:
-                s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]"
+                if isinstance(time_ref[0], tuple):
+                    # each element is a <start, end> pair
+                    s = (
+                        "["
+                        + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_ref])
+                        + "]"
+                    )
+                else:
+                    # each element is a float number
+                    s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]"
                 print(f"{cut_id}:\ttimestamp_ref={s}", file=f)
             if len(time_hyp) > 0:
-                s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
+                if isinstance(time_hyp[0], tuple):
+                    # each element is a <start, end> pair
+                    s = (
+                        "["
+                        + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_hyp])
+                        + "]"
+                    )
+                else:
+                    # each element is a float number
+                    s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
                 print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
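With the tuple branch above, a timestamp line written by store_transcripts_and_timestamps would read roughly as follows (illustrative cut id and times):

    1089-134686-0001:  timestamp_hyp=[(0.520, 0.960), (1.040, 1.600)]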
@@ -624,9 +646,18 @@ def write_error_stats(
 def write_error_stats_with_timestamps(
     f: TextIO,
     test_set_name: str,
-    results: List[Tuple[str, List[str], List[str], List[float], List[float]]],
+    results: List[
+        Tuple[
+            str,
+            List[str],
+            List[str],
+            List[Union[float, Tuple[float, float]]],
+            List[Union[float, Tuple[float, float]]],
+        ]
+    ],
     enable_log: bool = True,
-) -> Tuple[float, float, float]:
+    with_end_time: bool = False,
+) -> Tuple[float, Union[float, Tuple[float, float]], Union[float, Tuple[float, float]]]:
     """Write statistics based on predicted results and reference transcripts
     as well as their timestamps.
@@ -659,6 +690,8 @@ def write_error_stats_with_timestamps(
       enable_log:
         If True, also print detailed WER to the console.
         Otherwise, it is written only to the given file.
+      with_end_time:
+        Whether to use end timestamps.

     Returns:
       Return total word error rate and mean delay.
@@ -704,6 +737,14 @@ def write_error_stats_with_timestamps(
                 words[ref_word][0] += 1
                 num_corr += 1
                 if has_time:
-                    all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
+                    if with_end_time:
+                        all_delay.append(
+                            (
+                                time_hyp[p_hyp][0] - time_ref[p_ref][0],
+                                time_hyp[p_hyp][1] - time_ref[p_ref][1],
+                            )
+                        )
+                    else:
+                        all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
                 p_hyp += 1
                 p_ref += 1
@@ -716,16 +757,39 @@ def write_error_stats_with_timestamps(
     ins_errs = sum(ins.values())
     del_errs = sum(dels.values())
     tot_errs = sub_errs + ins_errs + del_errs
-    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
+    tot_err_rate = float("%.2f" % (100.0 * tot_errs / ref_len))

-    mean_delay = "inf"
-    var_delay = "inf"
+    if with_end_time:
+        mean_delay = (float("inf"), float("inf"))
+        var_delay = (float("inf"), float("inf"))
+    else:
+        mean_delay = float("inf")
+        var_delay = float("inf")
     num_delay = len(all_delay)
     if num_delay > 0:
-        mean_delay = sum(all_delay) / num_delay
-        var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
-        mean_delay = "%.3f" % mean_delay
-        var_delay = "%.3f" % var_delay
+        if with_end_time:
+            all_delay_start = [i[0] for i in all_delay]
+            mean_delay_start = sum(all_delay_start) / num_delay
+            var_delay_start = (
+                sum([(i - mean_delay_start) ** 2 for i in all_delay_start]) / num_delay
+            )
+            all_delay_end = [i[1] for i in all_delay]
+            mean_delay_end = sum(all_delay_end) / num_delay
+            var_delay_end = (
+                sum([(i - mean_delay_end) ** 2 for i in all_delay_end]) / num_delay
+            )
+            mean_delay = (
+                float("%.3f" % mean_delay_start),
+                float("%.3f" % mean_delay_end),
+            )
+            var_delay = (float("%.3f" % var_delay_start), float("%.3f" % var_delay_end))
+        else:
+            mean_delay = sum(all_delay) / num_delay
+            var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
+            mean_delay = float("%.3f" % mean_delay)
+            var_delay = float("%.3f" % var_delay)

     if enable_log:
         logging.info(
@@ -734,7 +798,8 @@ def write_error_stats_with_timestamps(
             f"{del_errs} del, {sub_errs} sub ]"
         )
         logging.info(
-            f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} "  # noqa
+            f"[{test_set_name}] %symbol-delay mean (s): "
+            f"{mean_delay}, variance: {var_delay} "  # noqa
             f"computed on {num_delay} correct words"
         )
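As a quick sanity check of the statistics above, assume three correct words with made-up (start, end) delays in seconds:

    # Hypothetical all_delay collected with with_end_time=True.
    all_delay = [(0.04, 0.08), (0.00, 0.04), (0.02, 0.06)]
    num_delay = len(all_delay)                      # 3
    starts = [d[0] for d in all_delay]              # [0.04, 0.00, 0.02]
    mean_start = sum(starts) / num_delay            # 0.02
    var_start = sum((x - mean_start) ** 2 for x in starts) / num_delay  # ~0.000267
    # After rounding with "%.3f", mean_delay == (0.02, 0.06) and var_delay == (0.0, 0.0).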
@@ -817,7 +882,8 @@ def write_error_stats_with_timestamps(
         hyp_count = corr + hyp_sub + ins

         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
-    return float(tot_err_rate), float(mean_delay), float(var_delay)
+
+    return tot_err_rate, mean_delay, var_delay


 class MetricsTracker(collections.defaultdict):
@@ -1431,3 +1497,270 @@ def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
             batch["supervisions"][k] = v[:keep_num_utt]

     return batch
+
+
+def parse_bpe_start_end_pairs(
+    tokens: List[str], is_first_token: List[bool]
+) -> List[Tuple[int, int]]:
+    """Parse pairs of start and end frame indexes for each word.
+
+    Args:
+      tokens:
+        List of BPE tokens.
+      is_first_token:
+        List of bool values, which indicates whether it is the first token,
+        i.e., not repeat or blank.
+
+    Returns:
+      List of (start-frame-index, end-frame-index) pairs for each word.
+    """
+    assert len(tokens) == len(is_first_token), (len(tokens), len(is_first_token))
+
+    start_token = b"\xe2\x96\x81".decode()  # '▁'
+    blank_token = "<blk>"
+
+    non_blank_idx = [i for i in range(len(tokens)) if tokens[i] != blank_token]
+    num_non_blank = len(non_blank_idx)
+
+    pairs = []
+    start = -1
+    end = -1
+    for j in range(num_non_blank):
+        # The index in all frames
+        i = non_blank_idx[j]
+
+        found_start = False
+        if is_first_token[i] and (j == 0 or tokens[i].startswith(start_token)):
+            found_start = True
+            if tokens[i] == start_token:
+                if j == num_non_blank - 1:
+                    # It is the last non-blank token
+                    found_start = False
+                elif is_first_token[non_blank_idx[j + 1]] and tokens[
+                    non_blank_idx[j + 1]
+                ].startswith(start_token):
+                    # The next non-blank token is a first-token and also starts with start_token
+                    found_start = False
+        if found_start:
+            start = i
+
+        if start != -1:
+            found_end = False
+            if j == num_non_blank - 1:
+                # It is the last non-blank token
+                found_end = True
+            elif is_first_token[non_blank_idx[j + 1]] and tokens[
+                non_blank_idx[j + 1]
+            ].startswith(start_token):
+                # The next non-blank token is a first-token and also starts with start_token
+                found_end = True
+            if found_end:
+                end = i
+
+        if start != -1 and end != -1:
+            if not all([tokens[t] == start_token for t in range(start, end + 1)]):
+                # except the case of all start_token
+                pairs.append((start, end))
+            # Reset start and end
+            start = -1
+            end = -1
+
+    return pairs
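To make the pairing logic concrete, here is a rough hand trace on made-up frame-level tokens (the "▁"-prefixed pieces mark word starts and "<blk>" is the blank):

    tokens = ["▁HE", "▁HE", "LL", "O", "<blk>", "▁WORLD", "▁WORLD", "<blk>"]
    # True only on the first frame of each emitted token (no repeats, no blanks).
    is_first_token = [True, False, True, True, False, True, False, False]
    # parse_bpe_start_end_pairs(tokens, is_first_token) should yield [(0, 3), (5, 6)]:
    # "HELLO" spans frames 0-3 and "WORLD" spans frames 5-6.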
+
+
+def parse_bpe_timestamps_and_texts(
+    best_paths: k2.Fsa, sp: spm.SentencePieceProcessor
+) -> Tuple[List[Tuple[int, int]], List[List[str]]]:
+    """Parse timestamps (frame indexes) and texts.
+
+    Args:
+      best_paths:
+        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
+        containing multiple FSAs, which is expected to be the result
+        of k2.shortest_path (otherwise the returned values won't
+        be meaningful). Its attributes `labels` and `aux_labels`
+        are both BPE tokens.
+      sp:
+        The BPE model.
+
+    Returns:
+      utt_index_pairs:
+        A list of pair lists. utt_index_pairs[i] is a list of
+        (start-frame-index, end-frame-index) pairs for each word in
+        utterance-i.
+      utt_words:
+        A list of str lists. utt_words[i] is a word list of utterance-i.
+    """
+    shape = best_paths.arcs.shape().remove_axis(1)
+
+    # labels: [utt][arcs]
+    labels = k2.RaggedTensor(shape, best_paths.labels.contiguous())
+    # remove -1's.
+    labels = labels.remove_values_eq(-1)
+    labels = labels.tolist()
+
+    # aux_labels: [utt][arcs]
+    aux_labels = k2.RaggedTensor(shape, best_paths.aux_labels.contiguous())
+    # remove -1's.
+    all_aux_labels = aux_labels.remove_values_eq(-1)
+    # len(all_aux_labels[i]) is equal to the number of frames
+    all_aux_labels = all_aux_labels.tolist()
+    # remove 0's and -1's.
+    out_aux_labels = aux_labels.remove_values_leq(0)
+    # len(out_aux_labels[i]) is equal to the number of output BPE tokens
+    out_aux_labels = out_aux_labels.tolist()
+
+    utt_index_pairs = []
+    utt_words = []
+    for i in range(len(labels)):
+        tokens = sp.id_to_piece(labels[i])
+        words = sp.decode(out_aux_labels[i]).split()
+
+        # Indicates whether it is the first token, i.e., not a repeat and not a blank.
+        is_first_token = [a != 0 for a in all_aux_labels[i]]
+        index_pairs = parse_bpe_start_end_pairs(tokens, is_first_token)
+        assert len(index_pairs) == len(words), (len(index_pairs), len(words), tokens)
+        utt_index_pairs.append(index_pairs)
+        utt_words.append(words)
+
+    return utt_index_pairs, utt_words
+
+
+def parse_timestamps_and_texts(
+    best_paths: k2.Fsa, word_table: k2.SymbolTable
+) -> Tuple[List[Tuple[int, int]], List[List[str]]]:
+    """Parse timestamps (frame indexes) and texts.
+
+    Args:
+      best_paths:
+        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
+        containing multiple FSAs, which is expected to be the result
+        of k2.shortest_path (otherwise the returned values won't
+        be meaningful). Attribute `labels` is the prediction unit,
+        e.g., phone or BPE tokens. Attribute `aux_labels` is the word index.
+      word_table:
+        The word symbol table.
+
+    Returns:
+      utt_index_pairs:
+        A list of pair lists. utt_index_pairs[i] is a list of
+        (start-frame-index, end-frame-index) pairs for each word in
+        utterance-i.
+      utt_words:
+        A list of str lists. utt_words[i] is a word list of utterance-i.
+    """
+    # [utt][words]
+    word_ids = get_texts(best_paths)
+
+    shape = best_paths.arcs.shape().remove_axis(1)
+
+    # labels: [utt][arcs]
+    labels = k2.RaggedTensor(shape, best_paths.labels.contiguous())
+    # remove -1's.
+    labels = labels.remove_values_eq(-1)
+    labels = labels.tolist()
+
+    # aux_labels: [utt][arcs]
+    aux_shape = shape.compose(best_paths.aux_labels.shape)
+    aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels.values.contiguous())
+    aux_labels = aux_labels.tolist()
+
+    utt_index_pairs = []
+    utt_words = []
+    for i, (label, aux_label) in enumerate(zip(labels, aux_labels)):
+        num_arcs = len(label)
+        # The last arc of aux_label is the arc entering the final state
+        assert num_arcs == len(aux_label) - 1, (num_arcs, len(aux_label))
+
+        index_pairs = []
+        start = -1
+        end = -1
+        for arc in range(num_arcs):
+            # len(aux_label[arc]) is 0 or 1
+            if label[arc] != 0 and len(aux_label[arc]) != 0:
+                if start != -1 and end != -1:
+                    index_pairs.append((start, end))
+                start = arc
+            if label[arc] != 0:
+                end = arc
+        if start != -1 and end != -1:
+            index_pairs.append((start, end))
+
+        words = [word_table[w] for w in word_ids[i]]
+        assert len(index_pairs) == len(words), (len(index_pairs), len(words))
+
+        utt_index_pairs.append(index_pairs)
+        utt_words.append(words)
+
+    return utt_index_pairs, utt_words
+
+
+def parse_fsa_timestamps_and_texts(
+    best_paths: k2.Fsa,
+    sp: Optional[spm.SentencePieceProcessor] = None,
+    word_table: Optional[k2.SymbolTable] = None,
+    subsampling_factor: int = 4,
+    frame_shift_ms: float = 10,
+) -> Tuple[List[Tuple[float, float]], List[List[str]]]:
+    """Parse timestamps (in seconds) and texts for given decoded fsa paths.
+    Currently it supports two cases:
+    (1) ctc-decoding, where the attributes `labels` and `aux_labels`
+        are both BPE tokens. In this case, sp should be provided.
+    (2) HLG-based 1best, where the attribute `labels` is the prediction unit,
+        e.g., phone or BPE tokens; attribute `aux_labels` is the word index.
+        In this case, word_table should be provided.
+
+    Args:
+      best_paths:
+        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
+        containing multiple FSAs, which is expected to be the result
+        of k2.shortest_path (otherwise the returned values won't
+        be meaningful).
+      sp:
+        The BPE model.
+      word_table:
+        The word symbol table.
+      subsampling_factor:
+        The subsampling factor of the model.
+      frame_shift_ms:
+        Frame shift in milliseconds between two contiguous frames.
+
+    Returns:
+      utt_time_pairs:
+        A list of pair lists. utt_time_pairs[i] is a list of
+        (start-time, end-time) pairs for each word in
+        utterance-i.
+      utt_words:
+        A list of str lists. utt_words[i] is a word list of utterance-i.
+    """
+    if sp is not None:
+        assert word_table is None, "word_table is not needed if sp is provided."
+        utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts(
+            best_paths=best_paths, sp=sp
+        )
+    elif word_table is not None:
+        assert sp is None, "sp is not needed if word_table is provided."
+        utt_index_pairs, utt_words = parse_timestamps_and_texts(
+            best_paths=best_paths, word_table=word_table
+        )
+    else:
+        raise ValueError("Either sp or word_table should be provided.")
+
+    utt_time_pairs = []
+    for utt in utt_index_pairs:
+        start = convert_timestamp(
+            frames=[i[0] for i in utt],
+            subsampling_factor=subsampling_factor,
+            frame_shift_ms=frame_shift_ms,
+        )
+        end = convert_timestamp(
+            # The duration in frames is (end_frame_index - start_frame_index + 1)
+            frames=[i[1] + 1 for i in utt],
+            subsampling_factor=subsampling_factor,
+            frame_shift_ms=frame_shift_ms,
+        )
+        utt_time_pairs.append(list(zip(start, end)))
+
+    return utt_time_pairs, utt_words
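For orientation, assuming convert_timestamp maps a frame index f to f * subsampling_factor * frame_shift_ms / 1000 seconds (a sketch of that helper, not quoted from it), the frame pairs from the test below would convert as:

    # With subsampling_factor=4 and frame_shift_ms=10, utt_index_pairs[0] == [(0, 3), (6, 8)]:
    #   starts: 0 -> 0.00 s,  6 -> 0.24 s
    #   ends:   3 + 1 -> 0.16 s,  8 + 1 -> 0.36 s
    # so utt_time_pairs[0] == [(0.0, 0.16), (0.24, 0.36)].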

test/test_parse_timestamp.py (new executable file, 154 lines)

@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+
+import k2
+import sentencepiece as spm
+import torch
+
+from icefall.lexicon import Lexicon
+from icefall.utils import parse_bpe_timestamps_and_texts, parse_timestamps_and_texts
+
+ICEFALL_DIR = Path(__file__).resolve().parent.parent
+
+
+def test_parse_bpe_timestamps_and_texts():
+    lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe_500"
+    if not lang_dir.is_dir():
+        print(f"{lang_dir} does not exist.")
+        return
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(str(lang_dir / "bpe.model"))
+
+    text_1 = "HELLO WORLD"
+    token_ids_1 = sp.encode(text_1, out_type=int)
+    # out_type=str: ['_HE', 'LL', 'O', '_WORLD']
+    # out_type=int: [22, 58, 24, 425]
+
+    # [22, 22, 58, 24, 0, 0, 425, 425, 425, 0, 0]
+    labels_1 = (
+        token_ids_1[0:1] * 2
+        + token_ids_1[1:3]
+        + [0] * 2
+        + token_ids_1[3:4] * 3
+        + [0] * 2
+    )
+    # [22, 0, 58, 24, 0, 0, 425, 0, 0, 0, 0, -1]
+    aux_labels_1 = (
+        token_ids_1[0:1]
+        + [0]
+        + token_ids_1[1:3]
+        + [0] * 2
+        + token_ids_1[3:4]
+        + [0] * 4
+        + [-1]
+    )
+
+    fsa_1 = k2.linear_fsa(labels_1)
+    fsa_1.aux_labels = torch.tensor(aux_labels_1).to(torch.int32)
+
+    text_2 = "SAY GOODBYE"
+    token_ids_2 = sp.encode(text_2, out_type=int)
+    # out_type=str: ['_SAY', '_GOOD', 'B', 'Y', 'E']
+    # out_type=int: [289, 286, 41, 16, 11]
+
+    # [289, 0, 0, 286, 286, 41, 16, 11, 0, 0]
+    labels_2 = (
+        token_ids_2[0:1] + [0] * 2 + token_ids_2[1:2] * 2 + token_ids_2[2:5] + [0] * 2
+    )
+    # [289, 0, 0, 286, 0, 41, 16, 11, 0, 0, -1]
+    aux_labels_2 = (
+        token_ids_2[0:1]
+        + [0] * 2
+        + token_ids_2[1:2]
+        + [0]
+        + token_ids_2[2:5]
+        + [0] * 2
+        + [-1]
+    )
+
+    fsa_2 = k2.linear_fsa(labels_2)
+    fsa_2.aux_labels = torch.tensor(aux_labels_2).to(torch.int32)
+
+    fsa_vec = k2.create_fsa_vec([fsa_1, fsa_2])
+    utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts(fsa_vec, sp)
+    assert utt_index_pairs[0] == [(0, 3), (6, 8)], utt_index_pairs[0]
+    assert utt_words[0] == ["HELLO", "WORLD"], utt_words[0]
+    assert utt_index_pairs[1] == [(0, 0), (3, 7)], utt_index_pairs[1]
+    assert utt_words[1] == ["SAY", "GOODBYE"], utt_words[1]
+
+
+def test_parse_timestamps_and_texts():
+    lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe_500"
+    if not lang_dir.is_dir():
+        print(f"{lang_dir} does not exist.")
+        return
+
+    lexicon = Lexicon(lang_dir)
+    sp = spm.SentencePieceProcessor()
+    sp.load(str(lang_dir / "bpe.model"))
+    word_table = lexicon.word_table
+
+    text_1 = "HELLO WORLD"
+    token_ids_1 = sp.encode(text_1, out_type=int)
+    # out_type=str: ['_HE', 'LL', 'O', '_WORLD']
+    # out_type=int: [22, 58, 24, 425]
+    word_ids_1 = [word_table[s] for s in text_1.split()]  # [79677, 196937]
+
+    # [22, 22, 58, 24, 0, 0, 425, 425, 425, 0, 0]
+    labels_1 = (
+        token_ids_1[0:1] * 2
+        + token_ids_1[1:3]
+        + [0] * 2
+        + token_ids_1[3:4] * 3
+        + [0] * 2
+    )
+    # [[79677], [], [], [], [], [], [196937], [], [], [], [], []]
+    aux_labels_1 = [word_ids_1[0:1]] + [[]] * 5 + [word_ids_1[1:2]] + [[]] * 5
+
+    fsa_1 = k2.linear_fsa(labels_1)
+    fsa_1.aux_labels = k2.RaggedTensor(aux_labels_1)
+
+    text_2 = "SAY GOODBYE"
+    token_ids_2 = sp.encode(text_2, out_type=int)
+    # out_type=str: ['_SAY', '_GOOD', 'B', 'Y', 'E']
+    # out_type=int: [289, 286, 41, 16, 11]
+    word_ids_2 = [word_table[s] for s in text_2.split()]  # [154967, 72079]
+
+    # [289, 0, 0, 286, 286, 41, 16, 11, 0, 0]
+    labels_2 = (
+        token_ids_2[0:1] + [0] * 2 + token_ids_2[1:2] * 2 + token_ids_2[2:5] + [0] * 2
+    )
+    # [[154967], [], [], [72079], [], [], [], [], [], [], []]
+    aux_labels_2 = [word_ids_2[0:1]] + [[]] * 2 + [word_ids_2[1:2]] + [[]] * 7
+
+    fsa_2 = k2.linear_fsa(labels_2)
+    fsa_2.aux_labels = k2.RaggedTensor(aux_labels_2)
+
+    fsa_vec = k2.create_fsa_vec([fsa_1, fsa_2])
+    utt_index_pairs, utt_words = parse_timestamps_and_texts(fsa_vec, word_table)
+    assert utt_index_pairs[0] == [(0, 3), (6, 8)], utt_index_pairs[0]
+    assert utt_words[0] == ["HELLO", "WORLD"], utt_words[0]
+    assert utt_index_pairs[1] == [(0, 0), (3, 7)], utt_index_pairs[1]
+    assert utt_words[1] == ["SAY", "GOODBYE"], utt_words[1]
+
+
+if __name__ == "__main__":
+    test_parse_bpe_timestamps_and_texts()
+    test_parse_timestamps_and_texts()