Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-11 11:02:29 +00:00)

Commit e3beb93e1d: Merge branch 'k2-fsa:master' into repeat-k
@@ -96,8 +96,7 @@ from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     get_texts,
-    get_texts_with_timestamp,
-    parse_hyp_and_timestamp,
+    parse_fsa_timestamps_and_texts,
     setup_logger,
     store_transcripts_and_timestamps,
     str2bool,
@@ -396,13 +395,8 @@ def decode_one_batch(
         best_path = one_best_decoding(
             lattice=lattice, use_double_scores=params.use_double_scores
         )
-        # Note: `best_path.aux_labels` contains token IDs, not word IDs
-        # since we are using H, not HLG here.
-        #
-        # token_ids is a lit-of-list of IDs
-        res = get_texts_with_timestamp(best_path)
-        hyps, timestamps = parse_hyp_and_timestamp(
-            res=res,
+        timestamps, hyps = parse_fsa_timestamps_and_texts(
+            best_paths=best_path,
             sp=bpe_model,
             subsampling_factor=params.subsampling_factor,
             frame_shift_ms=params.frame_shift_ms,
@@ -435,12 +429,11 @@ def decode_one_batch(
                 lattice=lattice, use_double_scores=params.use_double_scores
             )
             key = f"no_rescore_hlg_scale_{params.hlg_scale}"
-            res = get_texts_with_timestamp(best_path)
-            hyps, timestamps = parse_hyp_and_timestamp(
-                res=res,
+            timestamps, hyps = parse_fsa_timestamps_and_texts(
+                best_paths=best_path,
+                word_table=word_table,
                 subsampling_factor=params.subsampling_factor,
                 frame_shift_ms=params.frame_shift_ms,
-                word_table=word_table,
             )
         else:
             best_path = nbest_decoding(
@@ -504,7 +497,18 @@ def decode_dataset(
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
+) -> Dict[
+    str,
+    List[
+        Tuple[
+            str,
+            List[str],
+            List[str],
+            List[Tuple[float, float]],
+            List[Tuple[float, float]],
+        ]
+    ],
+]:
     """Decode dataset.

     Args:
@@ -555,7 +559,7 @@ def decode_dataset(
             time = []
             if s.alignment is not None and "word" in s.alignment:
                 time = [
-                    aliword.start
+                    (aliword.start, aliword.end)
                     for aliword in s.alignment["word"]
                     if aliword.symbol != ""
                 ]
@@ -601,7 +605,15 @@ def save_results(
     test_set_name: str,
     results_dict: Dict[
         str,
-        List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
+        List[
+            Tuple[
+                List[str],
+                List[str],
+                List[str],
+                List[Tuple[float, float]],
+                List[Tuple[float, float]],
+            ]
+        ],
     ],
 ):
     test_set_wers = dict()
@@ -621,7 +633,11 @@ def save_results(
         )
         with open(errs_filename, "w") as f:
             wer, mean_delay, var_delay = write_error_stats_with_timestamps(
-                f, f"{test_set_name}-{key}", results, enable_log=True
+                f,
+                f"{test_set_name}-{key}",
+                results,
+                enable_log=True,
+                with_end_time=True,
             )
         test_set_wers[key] = wer
         test_set_delays[key] = (mean_delay, var_delay)
@@ -637,16 +653,17 @@ def save_results(
         for key, val in test_set_wers:
             print("{}\t{}".format(key, val), file=f)

-    test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
+    # sort according to the mean start symbol delay
+    test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0])
     delays_info = (
         params.res_dir
         / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(delays_info, "w") as f:
-        print("settings\tsymbol-delay", file=f)
+        print("settings\t(start, end) symbol-delay (s) (start, end)", file=f)
         for key, val in test_set_delays:
             print(
-                "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
+                "{}\tmean: {}, variance: {}".format(key, val[0], val[1]),
                 file=f,
             )

@@ -657,10 +674,12 @@ def save_results(
         note = ""
     logging.info(s)

-    s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name)
+    s = "\nFor {}, (start, end) symbol-delay (s) of different settings are:\n".format(
+        test_set_name
+    )
     note = "\tbest for {}".format(test_set_name)
     for key, val in test_set_delays:
-        s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
+        s += "{}\tmean: {}, variance: {}{}\n".format(key, val[0], val[1], note)
         note = ""
     logging.info(s)

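
For reference, the two call patterns introduced above can be sketched as follows. This is a minimal illustration, not part of the diff; `lattice`, `bpe_model`, `word_table`, and `params` are assumed to come from the surrounding decode script.

from icefall.decode import one_best_decoding
from icefall.utils import parse_fsa_timestamps_and_texts

best_path = one_best_decoding(lattice=lattice, use_double_scores=True)

# (1) CTC decoding: `labels` and `aux_labels` are both BPE tokens, so pass sp.
timestamps, hyps = parse_fsa_timestamps_and_texts(
    best_paths=best_path,
    sp=bpe_model,
    subsampling_factor=params.subsampling_factor,
    frame_shift_ms=params.frame_shift_ms,
)

# (2) HLG-based 1best: `aux_labels` are word indexes, so pass word_table instead.
timestamps, hyps = parse_fsa_timestamps_and_texts(
    best_paths=best_path,
    word_table=word_table,
    subsampling_factor=params.subsampling_factor,
    frame_shift_ms=params.frame_shift_ms,
)

# timestamps[i] is a list of (start, end) pairs in seconds, one per word of hyps[i].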

icefall/utils.py (367 changed lines)

@@ -1,5 +1,6 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
-#                                        Mingshuang Luo)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+#                                        Mingshuang Luo,
+#                                        Zengwei Yao)
 #
 # See ../../LICENSE for clarification regarding multiple authors
 #
@@ -453,11 +454,32 @@ def store_transcripts_and_timestamps(
         for cut_id, ref, hyp, time_ref, time_hyp in texts:
             print(f"{cut_id}:\tref={ref}", file=f)
             print(f"{cut_id}:\thyp={hyp}", file=f)

             if len(time_ref) > 0:
-                s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]"
+                if isinstance(time_ref[0], tuple):
+                    # each element is <start, end> pair
+                    s = (
+                        "["
+                        + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_ref])
+                        + "]"
+                    )
+                else:
+                    # each element is a float number
+                    s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]"
                 print(f"{cut_id}:\ttimestamp_ref={s}", file=f)
-            s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
-            print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
+
+            if len(time_hyp) > 0:
+                if isinstance(time_hyp[0], tuple):
+                    # each element is <start, end> pair
+                    s = (
+                        "["
+                        + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_hyp])
+                        + "]"
+                    )
+                else:
+                    # each element is a float number
+                    s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
+                print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)


 def write_error_stats(
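
A toy illustration of the new formatting branch above (values invented for the example): when the timestamp list holds (start, end) pairs, each pair is rendered with three decimals.

time_hyp = [(0.08, 0.24), (0.56, 0.88)]
s = "[" + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_hyp]) + "]"
print(s)  # [(0.080, 0.240), (0.560, 0.880)]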
@@ -624,9 +646,18 @@ def write_error_stats_with_timestamps(
 def write_error_stats_with_timestamps(
     f: TextIO,
     test_set_name: str,
-    results: List[Tuple[str, List[str], List[str], List[float], List[float]]],
+    results: List[
+        Tuple[
+            str,
+            List[str],
+            List[str],
+            List[Union[float, Tuple[float, float]]],
+            List[Union[float, Tuple[float, float]]],
+        ]
+    ],
     enable_log: bool = True,
-) -> Tuple[float, float, float]:
+    with_end_time: bool = False,
+) -> Tuple[float, Union[float, Tuple[float, float]], Union[float, Tuple[float, float]]]:
     """Write statistics based on predicted results and reference transcripts
     as well as their timestamps.

@@ -659,6 +690,8 @@ def write_error_stats_with_timestamps(
       enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.
+      with_end_time:
+        Whether use end timestamps.

     Returns:
       Return total word error rate and mean delay.
@@ -704,7 +737,15 @@ def write_error_stats_with_timestamps(
             words[ref_word][0] += 1
             num_corr += 1
             if has_time:
-                all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
+                if with_end_time:
+                    all_delay.append(
+                        (
+                            time_hyp[p_hyp][0] - time_ref[p_ref][0],
+                            time_hyp[p_hyp][1] - time_ref[p_ref][1],
+                        )
+                    )
+                else:
+                    all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
             p_hyp += 1
             p_ref += 1
         if has_time:
@@ -716,16 +757,39 @@ def write_error_stats_with_timestamps(
     ins_errs = sum(ins.values())
     del_errs = sum(dels.values())
     tot_errs = sub_errs + ins_errs + del_errs
-    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
+    tot_err_rate = float("%.2f" % (100.0 * tot_errs / ref_len))

-    mean_delay = "inf"
-    var_delay = "inf"
+    if with_end_time:
+        mean_delay = (float("inf"), float("inf"))
+        var_delay = (float("inf"), float("inf"))
+    else:
+        mean_delay = float("inf")
+        var_delay = float("inf")
     num_delay = len(all_delay)
     if num_delay > 0:
-        mean_delay = sum(all_delay) / num_delay
-        var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
-        mean_delay = "%.3f" % mean_delay
-        var_delay = "%.3f" % var_delay
+        if with_end_time:
+            all_delay_start = [i[0] for i in all_delay]
+            mean_delay_start = sum(all_delay_start) / num_delay
+            var_delay_start = (
+                sum([(i - mean_delay_start) ** 2 for i in all_delay_start]) / num_delay
+            )
+
+            all_delay_end = [i[1] for i in all_delay]
+            mean_delay_end = sum(all_delay_end) / num_delay
+            var_delay_end = (
+                sum([(i - mean_delay_end) ** 2 for i in all_delay_end]) / num_delay
+            )
+
+            mean_delay = (
+                float("%.3f" % mean_delay_start),
+                float("%.3f" % mean_delay_end),
+            )
+            var_delay = (float("%.3f" % var_delay_start), float("%.3f" % var_delay_end))
+        else:
+            mean_delay = sum(all_delay) / num_delay
+            var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
+            mean_delay = float("%.3f" % mean_delay)
+            var_delay = float("%.3f" % var_delay)

     if enable_log:
         logging.info(
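
A quick sanity check of the per-component statistics above, on made-up delays (each element is a (start, end) delay pair in seconds):

all_delay = [(0.02, -0.01), (0.04, 0.03), (0.00, 0.01)]
num_delay = len(all_delay)

all_delay_start = [i[0] for i in all_delay]
mean_delay_start = sum(all_delay_start) / num_delay
var_delay_start = sum([(i - mean_delay_start) ** 2 for i in all_delay_start]) / num_delay

all_delay_end = [i[1] for i in all_delay]
mean_delay_end = sum(all_delay_end) / num_delay
var_delay_end = sum([(i - mean_delay_end) ** 2 for i in all_delay_end]) / num_delay

print(float("%.3f" % mean_delay_start), float("%.3f" % mean_delay_end))  # 0.02 0.01
print(float("%.5f" % var_delay_start), float("%.5f" % var_delay_end))    # 0.00027 0.00027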
@@ -734,7 +798,8 @@ def write_error_stats_with_timestamps(
             f"{del_errs} del, {sub_errs} sub ]"
         )
         logging.info(
-            f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} "  # noqa
+            f"[{test_set_name}] %symbol-delay mean (s): "
+            f"{mean_delay}, variance: {var_delay} "  # noqa
             f"computed on {num_delay} correct words"
         )

@@ -817,7 +882,8 @@ def write_error_stats_with_timestamps(
         hyp_count = corr + hyp_sub + ins

         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
-    return float(tot_err_rate), float(mean_delay), float(var_delay)
+
+    return tot_err_rate, mean_delay, var_delay


 class MetricsTracker(collections.defaultdict):
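
Illustrative shapes of the new return value (numbers invented): the WER is now returned as a float, and with with_end_time=True both delay statistics are (start, end) tuples, which is what the save_results hunks above rely on.

wer = 2.43
mean_delay = (0.025, 0.031)  # (mean start delay, mean end delay), in seconds
var_delay = (0.001, 0.002)

test_set_delays = {"ctc-decoding": (mean_delay, var_delay)}
# save_results sorts by x[1][0][0], i.e. by the mean start delay:
ordered = sorted(test_set_delays.items(), key=lambda x: x[1][0][0])
print(ordered[0][0])  # ctc-decoding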
@@ -1431,3 +1497,270 @@ def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
             batch["supervisions"][k] = v[:keep_num_utt]

     return batch
+
+
+def parse_bpe_start_end_pairs(
+    tokens: List[str], is_first_token: List[bool]
+) -> List[Tuple[int, int]]:
+    """Parse pairs of start and end frame indexes for each word.
+
+    Args:
+      tokens:
+        List of BPE tokens.
+      is_first_token:
+        List of bool values, which indicates whether it is the first token,
+        i.e., not repeat or blank.
+
+    Returns:
+      List of (start-frame-index, end-frame-index) pairs for each word.
+    """
+    assert len(tokens) == len(is_first_token), (len(tokens), len(is_first_token))
+
+    start_token = b"\xe2\x96\x81".decode()  # '_'
+    blank_token = "<blk>"
+
+    non_blank_idx = [i for i in range(len(tokens)) if tokens[i] != blank_token]
+    num_non_blank = len(non_blank_idx)
+
+    pairs = []
+    start = -1
+    end = -1
+    for j in range(num_non_blank):
+        # The index in all frames
+        i = non_blank_idx[j]
+
+        found_start = False
+        if is_first_token[i] and (j == 0 or tokens[i].startswith(start_token)):
+            found_start = True
+            if tokens[i] == start_token:
+                if j == num_non_blank - 1:
+                    # It is the last non-blank token
+                    found_start = False
+                elif is_first_token[non_blank_idx[j + 1]] and tokens[
+                    non_blank_idx[j + 1]
+                ].startswith(start_token):
+                    # The next not-blank token is a first-token and also starts with start_token
+                    found_start = False
+        if found_start:
+            start = i
+
+        if start != -1:
+            found_end = False
+            if j == num_non_blank - 1:
+                # It is the last non-blank token
+                found_end = True
+            elif is_first_token[non_blank_idx[j + 1]] and tokens[
+                non_blank_idx[j + 1]
+            ].startswith(start_token):
+                # The next not-blank token is a first-token and also starts with start_token
+                found_end = True
+            if found_end:
+                end = i
+
+        if start != -1 and end != -1:
+            if not all([tokens[t] == start_token for t in range(start, end + 1)]):
+                # except the case of all start_token
+                pairs.append((start, end))
+            # Reset start and end
+            start = -1
+            end = -1
+
+    return pairs
+
+
+def parse_bpe_timestamps_and_texts(
+    best_paths: k2.Fsa, sp: spm.SentencePieceProcessor
+) -> Tuple[List[Tuple[int, int]], List[List[str]]]:
+    """Parse timestamps (frame indexes) and texts.
+
+    Args:
+      best_paths:
+        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
+        containing multiple FSAs, which is expected to be the result
+        of k2.shortest_path (otherwise the returned values won't
+        be meaningful). Its attribtutes `labels` and `aux_labels`
+        are both BPE tokens.
+      sp:
+        The BPE model.
+
+    Returns:
+      utt_index_pairs:
+        A list of pair list. utt_index_pairs[i] is a list of
+        (start-frame-index, end-frame-index) pairs for each word in
+        utterance-i.
+      utt_words:
+        A list of str list. utt_words[i] is a word list of utterence-i.
+    """
+    shape = best_paths.arcs.shape().remove_axis(1)
+
+    # labels: [utt][arcs]
+    labels = k2.RaggedTensor(shape, best_paths.labels.contiguous())
+    # remove -1's.
+    labels = labels.remove_values_eq(-1)
+    labels = labels.tolist()
+
+    # aux_labels: [utt][arcs]
+    aux_labels = k2.RaggedTensor(shape, best_paths.aux_labels.contiguous())
+
+    # remove -1's.
+    all_aux_labels = aux_labels.remove_values_eq(-1)
+    # len(all_aux_labels[i]) is equal to the number of frames
+    all_aux_labels = all_aux_labels.tolist()
+
+    # remove 0's and -1's.
+    out_aux_labels = aux_labels.remove_values_leq(0)
+    # len(out_aux_labels[i]) is equal to the number of output BPE tokens
+    out_aux_labels = out_aux_labels.tolist()
+
+    utt_index_pairs = []
+    utt_words = []
+    for i in range(len(labels)):
+        tokens = sp.id_to_piece(labels[i])
+        words = sp.decode(out_aux_labels[i]).split()
+
+        # Indicates whether it is the first token, i.e., not-repeat and not-blank.
+        is_first_token = [a != 0 for a in all_aux_labels[i]]
+        index_pairs = parse_bpe_start_end_pairs(tokens, is_first_token)
+        assert len(index_pairs) == len(words), (len(index_pairs), len(words), tokens)
+        utt_index_pairs.append(index_pairs)
+        utt_words.append(words)
+
+    return utt_index_pairs, utt_words
+
+
+def parse_timestamps_and_texts(
+    best_paths: k2.Fsa, word_table: k2.SymbolTable
+) -> Tuple[List[Tuple[int, int]], List[List[str]]]:
+    """Parse timestamps (frame indexes) and texts.
+
+    Args:
+      best_paths:
+        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
+        containing multiple FSAs, which is expected to be the result
+        of k2.shortest_path (otherwise the returned values won't
+        be meaningful). Attribtute `labels` is the prediction unit,
+        e.g., phone or BPE tokens. Attribute `aux_labels` is the word index.
+      word_table:
+        The word symbol table.
+
+    Returns:
+      utt_index_pairs:
+        A list of pair list. utt_index_pairs[i] is a list of
+        (start-frame-index, end-frame-index) pairs for each word in
+        utterance-i.
+      utt_words:
+        A list of str list. utt_words[i] is a word list of utterence-i.
+    """
+    # [utt][words]
+    word_ids = get_texts(best_paths)
+
+    shape = best_paths.arcs.shape().remove_axis(1)
+
+    # labels: [utt][arcs]
+    labels = k2.RaggedTensor(shape, best_paths.labels.contiguous())
+    # remove -1's.
+    labels = labels.remove_values_eq(-1)
+    labels = labels.tolist()
+
+    # aux_labels: [utt][arcs]
+    aux_shape = shape.compose(best_paths.aux_labels.shape)
+    aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels.values.contiguous())
+    aux_labels = aux_labels.tolist()
+
+    utt_index_pairs = []
+    utt_words = []
+    for i, (label, aux_label) in enumerate(zip(labels, aux_labels)):
+        num_arcs = len(label)
+        # The last arc of aux_label is the arc entering the final state
+        assert num_arcs == len(aux_label) - 1, (num_arcs, len(aux_label))
+
+        index_pairs = []
+        start = -1
+        end = -1
+        for arc in range(num_arcs):
+            # len(aux_label[arc]) is 0 or 1
+            if label[arc] != 0 and len(aux_label[arc]) != 0:
+                if start != -1 and end != -1:
+                    index_pairs.append((start, end))
+                start = arc
+            if label[arc] != 0:
+                end = arc
+        if start != -1 and end != -1:
+            index_pairs.append((start, end))
+
+        words = [word_table[w] for w in word_ids[i]]
+        assert len(index_pairs) == len(words), (len(index_pairs), len(words))
+
+        utt_index_pairs.append(index_pairs)
+        utt_words.append(words)
+
+    return utt_index_pairs, utt_words
+
+
+def parse_fsa_timestamps_and_texts(
+    best_paths: k2.Fsa,
+    sp: Optional[spm.SentencePieceProcessor] = None,
+    word_table: Optional[k2.SymbolTable] = None,
+    subsampling_factor: int = 4,
+    frame_shift_ms: float = 10,
+) -> Tuple[List[Tuple[float, float]], List[List[str]]]:
+    """Parse timestamps (in seconds) and texts for given decoded fsa paths.
+    Currently it supports two cases:
+    (1) ctc-decoding, the attribtutes `labels` and `aux_labels`
+        are both BPE tokens. In this case, sp should be provided.
+    (2) HLG-based 1best, the attribtute `labels` is the prediction unit,
+        e.g., phone or BPE tokens; attribute `aux_labels` is the word index.
+        In this case, word_table should be provided.
+
+    Args:
+      best_paths:
+        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
+        containing multiple FSAs, which is expected to be the result
+        of k2.shortest_path (otherwise the returned values won't
+        be meaningful).
+      sp:
+        The BPE model.
+      word_table:
+        The word symbol table.
+      subsampling_factor:
+        The subsampling factor of the model.
+      frame_shift_ms:
+        Frame shift in milliseconds between two contiguous frames.
+
+    Returns:
+      utt_time_pairs:
+        A list of pair list. utt_time_pairs[i] is a list of
+        (start-time, end-time) pairs for each word in
+        utterance-i.
+      utt_words:
+        A list of str list. utt_words[i] is a word list of utterence-i.
+    """
+    if sp is not None:
+        assert word_table is None, "word_table is not needed if sp is provided."
+        utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts(
+            best_paths=best_paths, sp=sp
+        )
+    elif word_table is not None:
+        assert sp is None, "sp is not needed if word_table is provided."
+        utt_index_pairs, utt_words = parse_timestamps_and_texts(
+            best_paths=best_paths, word_table=word_table
+        )
+    else:
+        raise ValueError("Either sp or word_table should be provided.")
+
+    utt_time_pairs = []
+    for utt in utt_index_pairs:
+        start = convert_timestamp(
+            frames=[i[0] for i in utt],
+            subsampling_factor=subsampling_factor,
+            frame_shift_ms=frame_shift_ms,
+        )
+        end = convert_timestamp(
+            # The duration in frames is (end_frame_index - start_frame_index + 1)
+            frames=[i[1] + 1 for i in utt],
+            subsampling_factor=subsampling_factor,
+            frame_shift_ms=frame_shift_ms,
+        )
+        utt_time_pairs.append(list(zip(start, end)))
+
+    return utt_time_pairs, utt_words
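
The helper convert_timestamp used above is not shown in this diff. A hypothetical stand-in is sketched below, assuming it maps a frame index i to i * subsampling_factor * frame_shift_ms / 1000 seconds (an assumption, not confirmed by this commit); it is applied to the index pairs asserted in the test file that follows.

from typing import List


def convert_timestamp_sketch(
    frames: List[int], subsampling_factor: int = 4, frame_shift_ms: float = 10
) -> List[float]:
    # Hypothetical stand-in for icefall's convert_timestamp (not part of this diff).
    return [f * subsampling_factor * frame_shift_ms / 1000.0 for f in frames]


# Index pairs as asserted for "HELLO WORLD" in the test file below:
utt = [(0, 3), (6, 8)]
start = convert_timestamp_sketch([i[0] for i in utt])
end = convert_timestamp_sketch([i[1] + 1 for i in utt])  # end index is inclusive
print(list(zip(start, end)))  # [(0.0, 0.16), (0.24, 0.36)]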

test/test_parse_timestamp.py (new executable file, 154 lines)

@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from pathlib import Path
+
+import k2
+import sentencepiece as spm
+import torch
+
+from icefall.lexicon import Lexicon
+from icefall.utils import parse_bpe_timestamps_and_texts, parse_timestamps_and_texts
+
+ICEFALL_DIR = Path(__file__).resolve().parent.parent
+
+
+def test_parse_bpe_timestamps_and_texts():
+    lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe_500"
+    if not lang_dir.is_dir():
+        print(f"{lang_dir} does not exist.")
+        return
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(str(lang_dir / "bpe.model"))
+
+    text_1 = "HELLO WORLD"
+    token_ids_1 = sp.encode(text_1, out_type=int)
+    # out_type=str: ['_HE', 'LL', 'O', '_WORLD']
+    # out_type=int: [22, 58, 24, 425]
+
+    # [22, 22, 58, 24, 0, 0, 425, 425, 425, 0, 0]
+    labels_1 = (
+        token_ids_1[0:1] * 2
+        + token_ids_1[1:3]
+        + [0] * 2
+        + token_ids_1[3:4] * 3
+        + [0] * 2
+    )
+    # [22, 0, 58, 24, 0, 0, 425, 0, 0, 0, 0, -1]
+    aux_labels_1 = (
+        token_ids_1[0:1]
+        + [0]
+        + token_ids_1[1:3]
+        + [0] * 2
+        + token_ids_1[3:4]
+        + [0] * 4
+        + [-1]
+    )
+    fsa_1 = k2.linear_fsa(labels_1)
+    fsa_1.aux_labels = torch.tensor(aux_labels_1).to(torch.int32)
+
+    text_2 = "SAY GOODBYE"
+    token_ids_2 = sp.encode(text_2, out_type=int)
+    # out_type=str: ['_SAY', '_GOOD', 'B', 'Y', 'E']
+    # out_type=int: [289, 286, 41, 16, 11]
+
+    # [289, 0, 0, 286, 286, 41, 16, 11, 0, 0]
+    labels_2 = (
+        token_ids_2[0:1] + [0] * 2 + token_ids_2[1:2] * 2 + token_ids_2[2:5] + [0] * 2
+    )
+    # [289, 0, 0, 286, 0, 41, 16, 11, 0, 0, -1]
+    aux_labels_2 = (
+        token_ids_2[0:1]
+        + [0] * 2
+        + token_ids_2[1:2]
+        + [0]
+        + token_ids_2[2:5]
+        + [0] * 2
+        + [-1]
+    )
+    fsa_2 = k2.linear_fsa(labels_2)
+    fsa_2.aux_labels = torch.tensor(aux_labels_2).to(torch.int32)
+
+    fsa_vec = k2.create_fsa_vec([fsa_1, fsa_2])
+
+    utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts(fsa_vec, sp)
+    assert utt_index_pairs[0] == [(0, 3), (6, 8)], utt_index_pairs[0]
+    assert utt_words[0] == ["HELLO", "WORLD"], utt_words[0]
+    assert utt_index_pairs[1] == [(0, 0), (3, 7)], utt_index_pairs[1]
+    assert utt_words[1] == ["SAY", "GOODBYE"], utt_words[1]
+
+
+def test_parse_timestamps_and_texts():
+    lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe_500"
+    if not lang_dir.is_dir():
+        print(f"{lang_dir} does not exist.")
+        return
+
+    lexicon = Lexicon(lang_dir)
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(str(lang_dir / "bpe.model"))
+    word_table = lexicon.word_table
+
+    text_1 = "HELLO WORLD"
+    token_ids_1 = sp.encode(text_1, out_type=int)
+    # out_type=str: ['_HE', 'LL', 'O', '_WORLD']
+    # out_type=int: [22, 58, 24, 425]
+    word_ids_1 = [word_table[s] for s in text_1.split()]  # [79677, 196937]
+    # [22, 22, 58, 24, 0, 0, 425, 425, 425, 0, 0]
+    labels_1 = (
+        token_ids_1[0:1] * 2
+        + token_ids_1[1:3]
+        + [0] * 2
+        + token_ids_1[3:4] * 3
+        + [0] * 2
+    )
+    # [[79677], [], [], [], [], [], [196937], [], [], [], [], []]
+    aux_labels_1 = [word_ids_1[0:1]] + [[]] * 5 + [word_ids_1[1:2]] + [[]] * 5
+
+    fsa_1 = k2.linear_fsa(labels_1)
+    fsa_1.aux_labels = k2.RaggedTensor(aux_labels_1)
+
+    text_2 = "SAY GOODBYE"
+    token_ids_2 = sp.encode(text_2, out_type=int)
+    # out_type=str: ['_SAY', '_GOOD', 'B', 'Y', 'E']
+    # out_type=int: [289, 286, 41, 16, 11]
+    word_ids_2 = [word_table[s] for s in text_2.split()]  # [154967, 72079]
+    # [289, 0, 0, 286, 286, 41, 16, 11, 0, 0]
+    labels_2 = (
+        token_ids_2[0:1] + [0] * 2 + token_ids_2[1:2] * 2 + token_ids_2[2:5] + [0] * 2
+    )
+    # [[154967], [], [], [72079], [], [], [], [], [], [], []]
+    aux_labels_2 = [word_ids_2[0:1]] + [[]] * 2 + [word_ids_2[1:2]] + [[]] * 7
+
+    fsa_2 = k2.linear_fsa(labels_2)
+    fsa_2.aux_labels = k2.RaggedTensor(aux_labels_2)
+
+    fsa_vec = k2.create_fsa_vec([fsa_1, fsa_2])
+
+    utt_index_pairs, utt_words = parse_timestamps_and_texts(fsa_vec, word_table)
+    assert utt_index_pairs[0] == [(0, 3), (6, 8)], utt_index_pairs[0]
+    assert utt_words[0] == ["HELLO", "WORLD"], utt_words[0]
+    assert utt_index_pairs[1] == [(0, 0), (3, 7)], utt_index_pairs[1]
+    assert utt_words[1] == ["SAY", "GOODBYE"], utt_words[1]
+
+
+if __name__ == "__main__":
+    test_parse_bpe_timestamps_and_texts()
+    test_parse_timestamps_and_texts()