Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-12 11:32:19 +00:00)
add parse_fsa_timestamps_and_texts function, test in conformer_ctc3/decode.py
commit 3e1d14b9f8 (parent 0e4f7c59c2)
conformer_ctc3/decode.py

@@ -96,8 +96,7 @@ from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     get_texts,
-    get_texts_with_timestamp,
-    parse_hyp_and_timestamp,
+    parse_fsa_timestamps_and_texts,
     setup_logger,
     store_transcripts_and_timestamps,
     str2bool,
@@ -396,13 +395,8 @@ def decode_one_batch(
         best_path = one_best_decoding(
             lattice=lattice, use_double_scores=params.use_double_scores
         )
-        # Note: `best_path.aux_labels` contains token IDs, not word IDs
-        # since we are using H, not HLG here.
-        #
-        # token_ids is a list-of-list of IDs
-        res = get_texts_with_timestamp(best_path)
-        hyps, timestamps = parse_hyp_and_timestamp(
-            res=res,
+        timestamps, hyps = parse_fsa_timestamps_and_texts(
+            best_paths=best_path,
             sp=bpe_model,
             subsampling_factor=params.subsampling_factor,
             frame_shift_ms=params.frame_shift_ms,
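Note: the new call returns per-word timestamps as (start, end) pairs in seconds, alongside the word lists. A self-contained sketch of consuming these outputs, with sample values shaped like the function's return types (the values themselves are illustrative, not from the commit):

# hyps[i] is the word list of utterance i; timestamps[i] holds one
# (start, end) pair in seconds per decoded word of that utterance.
hyps = [["THE", "CAT"]]
timestamps = [[(0.00, 0.24), (0.28, 0.60)]]
for words, times in zip(hyps, timestamps):
    for word, (start, end) in zip(words, times):
        print(f"{word}\t{start:.3f}\t{end:.3f}")
# THE    0.000   0.240
# CAT    0.280   0.600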
@@ -435,12 +429,11 @@ def decode_one_batch(
             lattice=lattice, use_double_scores=params.use_double_scores
         )
         key = f"no_rescore_hlg_scale_{params.hlg_scale}"
-        res = get_texts_with_timestamp(best_path)
-        hyps, timestamps = parse_hyp_and_timestamp(
-            res=res,
+        timestamps, hyps = parse_fsa_timestamps_and_texts(
+            best_paths=best_path,
+            word_table=word_table,
             subsampling_factor=params.subsampling_factor,
             frame_shift_ms=params.frame_shift_ms,
-            word_table=word_table,
         )
     else:
         best_path = nbest_decoding(
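Note: the HLG-based branch passes word_table instead of sp, matching case (2) in the new function's docstring. A minimal usage sketch, assuming best_path comes from one_best_decoding as in the hunk above; the words.txt path is hypothetical:

import k2

word_table = k2.SymbolTable.from_file("data/lang_bpe_500/words.txt")  # hypothetical path
timestamps, hyps = parse_fsa_timestamps_and_texts(
    best_paths=best_path,
    word_table=word_table,
    subsampling_factor=4,
    frame_shift_ms=10,
)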
@@ -504,7 +497,18 @@ def decode_dataset(
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
+) -> Dict[
+    str,
+    List[
+        Tuple[
+            str,
+            List[str],
+            List[str],
+            List[Tuple[float, float]],
+            List[Tuple[float, float]],
+        ]
+    ],
+]:
     """Decode dataset.

     Args:
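Note: the return type widens the two timestamp lists from List[float] (start times only) to List[Tuple[float, float]] (start and end times). An assumed example of one entry under the new type; the decoding-method key and all values are illustrative:

results = {
    "ctc-decoding": [
        (
            "cut-001",                      # cut id
            ["THE", "CAT"],                 # ref words
            ["THE", "CAT"],                 # hyp words
            [(0.00, 0.24), (0.28, 0.60)],   # ref (start, end) times in seconds
            [(0.02, 0.26), (0.30, 0.58)],   # hyp (start, end) times in seconds
        )
    ]
}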
@@ -555,7 +559,7 @@ def decode_dataset(
                 time = []
                 if s.alignment is not None and "word" in s.alignment:
                     time = [
-                        aliword.start
+                        (aliword.start, aliword.end)
                         for aliword in s.alignment["word"]
                         if aliword.symbol != ""
                     ]
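Note: the reference times now mirror the hypothesis format. A small sketch of the changed comprehension, assuming lhotse's AlignmentItem (symbol/start/duration fields plus an end property); the alignment values are made up:

from lhotse.supervision import AlignmentItem

word_ali = [
    AlignmentItem(symbol="HELLO", start=0.00, duration=0.42),
    AlignmentItem(symbol="", start=0.42, duration=0.10),  # silence; filtered out
    AlignmentItem(symbol="WORLD", start=0.52, duration=0.38),
]
time = [(w.start, w.end) for w in word_ali if w.symbol != ""]
print(time)  # [(0.0, 0.42), (0.52, 0.9)]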
@@ -601,7 +605,15 @@ def save_results(
     test_set_name: str,
     results_dict: Dict[
         str,
-        List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
+        List[
+            Tuple[
+                List[str],
+                List[str],
+                List[str],
+                List[Tuple[float, float]],
+                List[Tuple[float, float]],
+            ]
+        ],
     ],
 ):
     test_set_wers = dict()
icefall/utils.py

@@ -454,9 +454,30 @@ def store_transcripts_and_timestamps(
         for cut_id, ref, hyp, time_ref, time_hyp in texts:
             print(f"{cut_id}:\tref={ref}", file=f)
             print(f"{cut_id}:\thyp={hyp}", file=f)
+
             if len(time_ref) > 0:
-                s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]"
+                if isinstance(time_ref[0], tuple):
+                    # each element is a <start, end> pair
+                    s = (
+                        "["
+                        + ", ".join(["(%0.3f, %0.3f)" % (i, j) for (i, j) in time_ref])
+                        + "]"
+                    )
+                else:
+                    # each element is a float number
+                    s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]"
                 print(f"{cut_id}:\ttimestamp_ref={s}", file=f)
+
             if len(time_hyp) > 0:
-                s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
+                if isinstance(time_hyp[0], tuple):
+                    # each element is a <start, end> pair
+                    s = (
+                        "["
+                        + ", ".join(["(%0.3f, %0.3f)" % (i, j) for (i, j) in time_hyp])
+                        + "]"
+                    )
+                else:
+                    # each element is a float number
+                    s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
                 print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
+
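Note: the same isinstance branch is duplicated for time_ref and time_hyp. A self-contained sketch of the shared formatting logic; format_timestamps is a hypothetical helper written for illustration, not something this commit adds:

from typing import List, Tuple, Union

def format_timestamps(times: List[Union[float, Tuple[float, float]]]) -> str:
    """Render "[0.120, ...]" for floats, "[(0.120, 0.360), ...]" for pairs."""
    if len(times) > 0 and isinstance(times[0], tuple):
        return "[" + ", ".join("(%0.3f, %0.3f)" % (i, j) for (i, j) in times) + "]"
    return "[" + ", ".join("%0.3f" % i for i in times) + "]"

print(format_timestamps([0.12, 0.48]))                # [0.120, 0.480]
print(format_timestamps([(0.12, 0.36), (0.4, 0.8)]))  # [(0.120, 0.360), (0.400, 0.800)]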
@@ -1493,6 +1514,8 @@ def parse_bpe_start_end_pairs(
             end = i

         if start != -1 and end != -1:
-            pairs.append((start, end))
+            if not all([tokens[t] == start_token for t in range(start, end + 1)]):
+                # except the case of all start_token
+                pairs.append((start, end))
             # Reset start and end
             start = -1
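Note: the new guard skips a candidate (start, end) range made up entirely of start_token entries. An illustration with assumed values; the surrounding loop state of parse_bpe_start_end_pairs is simplified away:

tokens = ["▁", "▁", "▁"]   # assumed: a run consisting only of the start token
start_token = "▁"
start, end = 0, 2
pairs = []
if not all([tokens[t] == start_token for t in range(start, end + 1)]):
    pairs.append((start, end))
print(pairs)  # [] -- the all-start_token run is dropped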
@@ -1554,7 +1577,7 @@ def parse_bpe_timestamps_and_texts(
         # Indicates whether it is the first token, i.e., not-repeat and not-blank.
         is_first_token = [a != 0 for a in all_aux_labels[i]]
         index_pairs = parse_bpe_start_end_pairs(tokens, is_first_token)
-        assert len(index_pairs) == len(words), (len(index_pairs), len(words))
+        assert len(index_pairs) == len(words), (len(index_pairs), len(words), tokens)
         utt_index_pairs.append(index_pairs)
         utt_words.append(words)

@@ -1628,3 +1651,72 @@ def parse_timestamps_and_texts(
         utt_words.append(words)

     return utt_index_pairs, utt_words
+
+
+def parse_fsa_timestamps_and_texts(
+    best_paths: k2.Fsa,
+    sp: Optional[spm.SentencePieceProcessor] = None,
+    word_table: Optional[k2.SymbolTable] = None,
+    subsampling_factor: int = 4,
+    frame_shift_ms: float = 10,
+) -> Tuple[List[Tuple[float, float]], List[List[str]]]:
+    """Parse timestamps (in seconds) and texts for given decoded fsa paths.
+    Currently it supports two cases:
+    (1) ctc-decoding, where the attributes `labels` and `aux_labels`
+        are both BPE tokens. In this case, sp should be provided.
+    (2) HLG-based 1best, where the attribute `labels` is the prediction unit,
+        e.g., phone or BPE tokens; the attribute `aux_labels` is the word index.
+        In this case, word_table should be provided.
+
+    Args:
+      best_paths:
+        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.,
+        containing multiple FSAs, which is expected to be the result
+        of k2.shortest_path (otherwise the returned values won't
+        be meaningful).
+      sp:
+        The BPE model.
+      word_table:
+        The word symbol table.
+      subsampling_factor:
+        The subsampling factor of the model.
+      frame_shift_ms:
+        Frame shift in milliseconds between two contiguous frames.
+
+    Returns:
+      utt_time_pairs:
+        A list of pair lists. utt_time_pairs[i] is a list of
+        (start-time, end-time) pairs for each word in
+        utterance-i.
+      utt_words:
+        A list of str lists. utt_words[i] is the word list of utterance-i.
+    """
+    if sp is not None:
+        assert word_table is None, "word_table is not needed if sp is provided."
+        utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts(
+            best_paths=best_paths, sp=sp
+        )
+    elif word_table is not None:
+        assert sp is None, "sp is not needed if word_table is provided."
+        utt_index_pairs, utt_words = parse_timestamps_and_texts(
+            best_paths=best_paths, word_table=word_table
+        )
+    else:
+        raise ValueError("Either sp or word_table should be provided.")
+
+    utt_time_pairs = []
+    for utt in utt_index_pairs:
+        start = convert_timestamp(
+            frames=[i[0] for i in utt],
+            subsampling_factor=subsampling_factor,
+            frame_shift_ms=frame_shift_ms,
+        )
+        end = convert_timestamp(
+            # The duration in frames is (end_frame_index - start_frame_index + 1)
+            frames=[i[1] + 1 for i in utt],
+            subsampling_factor=subsampling_factor,
+            frame_shift_ms=frame_shift_ms,
+        )
+        utt_time_pairs.append(list(zip(start, end)))
+
+    return utt_time_pairs, utt_words
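Note: the second convert_timestamp call adds 1 to each end-frame index because a word occupying output frames [start, end] inclusive only ends after frame end. Assuming convert_timestamp maps a frame index t to t * subsampling_factor * frame_shift_ms / 1000 seconds, here is a hedged re-implementation sketch (convert_timestamp_sketch is illustrative, not the icefall source):

from typing import List

def convert_timestamp_sketch(
    frames: List[int], subsampling_factor: int, frame_shift_ms: float = 10
) -> List[float]:
    # seconds = frame index * subsampling factor * frame shift (ms) / 1000
    return [f * subsampling_factor * frame_shift_ms / 1000.0 for f in frames]

# A word spanning output frames 12..17 with subsampling_factor=4:
print(convert_timestamp_sketch([12], 4))      # [0.48] -> start time
print(convert_timestamp_sketch([17 + 1], 4))  # [0.72] -> end time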