Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
Get timestamps during decoding (#598)
* print out timestamps during decoding
* add word-level alignments
* support to compute mean symbol delay with word-level alignments
* print variance of symbol delay
* update doc
* support to compute delay for pruned_transducer_stateless4
* fix bug
* add doc
This commit is contained in:
parent ff3f026381
commit 03668771d7
12  egs/librispeech/ASR/add_alignments.sh  (Executable file)
@@ -0,0 +1,12 @@
#!/usr/bin/env bash

set -eou pipefail

alignments_dir=data/alignment
cuts_in_dir=data/fbank
cuts_out_dir=data/fbank_ali

python3 ./local/add_alignment_librispeech.py \
  --alignments-dir $alignments_dir \
  --cuts-in-dir $cuts_in_dir \
  --cuts-out-dir $cuts_out_dir
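For reference: download_alignments() in the Python script below unpacks LibriSpeech-Alignments.zip into the alignments dir, and add_alignment() then globs *.alignment.txt files under <alignments-dir>/LibriSpeech/<part>/. The extracted tree therefore looks roughly like this (a sketch; the actual speaker/chapter IDs vary per part):

data/alignment/
├── LibriSpeech-Alignments.zip
├── .ali_completed
└── LibriSpeech/
    ├── dev-clean/
    │   └── 1272/128104/1272-128104.alignment.txt
    ├── dev-other/
    ├── test-clean/
    ├── test-other/
    └── train-clean-100/ ...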
196  egs/librispeech/ASR/local/add_alignment_librispeech.py  (Executable file)
@@ -0,0 +1,196 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file adds alignments from https://github.com/CorentinJ/librispeech-alignments  # noqa
to the existing fbank features dir (e.g., data/fbank)
and saves the cuts to a new dir (e.g., data/fbank_ali).
"""

import argparse
import logging
import zipfile
from pathlib import Path
from typing import List

from lhotse import CutSet, load_manifest_lazy
from lhotse.recipes.librispeech import parse_alignments
from lhotse.utils import is_module_available

LIBRISPEECH_ALIGNMENTS_URL = (
    "https://drive.google.com/uc?id=1WYfgr31T-PPwMcxuAq09XZfHQO5Mw8fE"
)

DATASET_PARTS = [
    "dev-clean",
    "dev-other",
    "test-clean",
    "test-other",
    "train-clean-100",
    "train-clean-360",
    "train-other-500",
]


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--alignments-dir",
        type=str,
        default="data/alignment",
        help="The dir to save alignments.",
    )

    parser.add_argument(
        "--cuts-in-dir",
        type=str,
        default="data/fbank",
        help="The dir of the existing cuts without alignments.",
    )

    parser.add_argument(
        "--cuts-out-dir",
        type=str,
        default="data/fbank_ali",
        help="The dir to save the new cuts with alignments.",
    )

    return parser


def download_alignments(
    target_dir: str, alignments_url: str = LIBRISPEECH_ALIGNMENTS_URL
):
    """
    Download and extract the alignments.
    Modified from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/librispeech.py  # noqa

    Note: If you cannot access drive.google.com, you could download the file
    `LibriSpeech-Alignments.zip` from huggingface:
    https://huggingface.co/Zengwei/librispeech-alignments
    and extract the zip file manually.

    Args:
      target_dir:
        The dir to save alignments.
      alignments_url:
        The URL of the alignments.
    """
    target_dir = Path(target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    completed_detector = target_dir / ".ali_completed"
    if completed_detector.is_file():
        logging.info("The alignment files already exist.")
        return

    ali_zip_path = target_dir / "LibriSpeech-Alignments.zip"
    if not ali_zip_path.is_file():
        assert is_module_available(
            "gdown"
        ), 'To download LibriSpeech alignments, please run "pip install gdown"'  # noqa
        import gdown

        gdown.download(alignments_url, output=str(ali_zip_path))

    with zipfile.ZipFile(str(ali_zip_path)) as f:
        f.extractall(path=target_dir)
        completed_detector.touch()


def add_alignment(
    alignments_dir: str,
    cuts_in_dir: str = "data/fbank",
    cuts_out_dir: str = "data/fbank_ali",
    dataset_parts: List[str] = DATASET_PARTS,
):
    """
    Add alignment info to existing cuts.

    Args:
      alignments_dir:
        The dir of the alignments.
      cuts_in_dir:
        The dir of the existing cuts.
      cuts_out_dir:
        The dir to save the new cuts with alignments.
      dataset_parts:
        The LibriSpeech parts to add alignments to.
    """
    alignments_dir = Path(alignments_dir)
    cuts_in_dir = Path(cuts_in_dir)
    cuts_out_dir = Path(cuts_out_dir)
    cuts_out_dir.mkdir(parents=True, exist_ok=True)

    for part in dataset_parts:
        logging.info(f"Processing {part}")

        cuts_in_path = cuts_in_dir / f"librispeech_cuts_{part}.jsonl.gz"
        if not cuts_in_path.is_file():
            logging.info(f"{cuts_in_path} does not exist - skipping.")
            continue
        cuts_out_path = cuts_out_dir / f"librispeech_cuts_{part}.jsonl.gz"
        if cuts_out_path.is_file():
            logging.info(f"{part} already exists - skipping.")
            continue

        # parse alignments
        alignments = {}
        part_ali_dir = alignments_dir / "LibriSpeech" / part
        for ali_path in part_ali_dir.rglob("*.alignment.txt"):
            ali = parse_alignments(ali_path)
            alignments.update(ali)
        logging.info(
            f"{part} has {len(alignments.keys())} cuts with alignments."
        )

        # add the alignment attribute and write out
        cuts_in = load_manifest_lazy(cuts_in_path)
        with CutSet.open_writer(cuts_out_path) as writer:
            for cut in cuts_in:
                for idx, subcut in enumerate(cut.supervisions):
                    origin_id = subcut.id.split("_")[0]
                    if origin_id in alignments:
                        ali = alignments[origin_id]
                    else:
                        logging.info(
                            f"Warning: {origin_id} has no alignment."
                        )
                        ali = []
                    subcut.alignment = {"word": ali}
                writer.write(cut, flush=True)


def main():
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)

    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    download_alignments(args.alignments_dir)
    add_alignment(args.alignments_dir, args.cuts_in_dir, args.cuts_out_dir)


if __name__ == "__main__":
    main()
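Once the script has run, every supervision in the new manifests carries a word-level alignment. A minimal sketch for inspecting it with lhotse (assuming the test-clean manifest was written to data/fbank_ali as above):

#!/usr/bin/env python3
# Minimal sketch: print the word-level alignment of the first cut.
# Assumes data/fbank_ali/librispeech_cuts_test-clean.jsonl.gz was produced
# by ./local/add_alignment_librispeech.py above.
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank_ali/librispeech_cuts_test-clean.jsonl.gz")
for cut in cuts:
    sup = cut.supervisions[0]
    # sup.alignment["word"] is a list of lhotse AlignmentItem objects
    for item in sup.alignment["word"]:
        if item.symbol != "":  # empty symbols mark silence/gaps
            print(f"{item.symbol}\t{item.start:.2f}\t{item.duration:.2f}")
    break  # only the first cut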
@@ -91,6 +91,22 @@ Usage:
    --beam 20.0 \
    --max-contexts 8 \
    --max-states 64

To evaluate symbol delay, you should:
(1) Generate cuts with word-time alignments:
./local/add_alignment_librispeech.py \
    --alignments-dir data/alignment \
    --cuts-in-dir data/fbank \
    --cuts-out-dir data/fbank_ali
(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
For example:
./lstm_transducer_stateless3/decode.py \
    --epoch 40 \
    --avg 20 \
    --exp-dir ./lstm_transducer_stateless3/exp \
    --max-duration 600 \
    --decoding-method greedy_search \
    --manifest-dir data/fbank_ali
"""

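Conceptually, the symbol delay reported by these scripts compares, per word, the time at which the model emits it against the start time from the reference alignment. A simplified sketch of the statistic (the actual computation in write_error_stats_with_timestamps first aligns reference and hypothesis words and only compares matched pairs; the helper name below is hypothetical):

from typing import List, Tuple


def mean_var_symbol_delay(
    ref_times: List[float], hyp_times: List[float]
) -> Tuple[float, float]:
    # Assumes ref_times[i] and hyp_times[i] already refer to the same word.
    deltas = [h - r for r, h in zip(ref_times, hyp_times)]
    mean = sum(deltas) / len(deltas)
    var = sum((d - mean) ** 2 for d in deltas) / len(deltas)
    return mean, var


# Hypothesis words emitted ~60 ms late on average:
print(mean_var_symbol_delay([0.50, 1.20, 2.05], [0.56, 1.28, 2.09]))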
@@ -127,10 +143,12 @@ from icefall.checkpoint import (
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    DecodingResults,
    parse_hyp_and_timestamp,
    setup_logger,
    store_transcripts,
    store_transcripts_and_timestamps,
    str2bool,
    write_error_stats,
    write_error_stats_with_timestamps,
)

LOG_EPS = math.log(1e-10)
@ -314,7 +332,7 @@ def decode_one_batch(
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
@ -322,9 +340,11 @@ def decode_one_batch(
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
- value: It is a tuple. `len(value[0])` and `len(value[1])` are both
|
||||
equal to the batch size. `value[0][i]` and `value[1][i]`
|
||||
are the decoding result and timestamps for the i-th utterance
|
||||
in the given batch respectively.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
@ -343,8 +363,8 @@ def decode_one_batch(
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
Return the decoding result and timestamps. See above description for the
|
||||
format of the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
@ -370,10 +390,8 @@ def decode_one_batch(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
res = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -381,11 +399,10 @@ def decode_one_batch(
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
res = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -395,11 +412,10 @@ def decode_one_batch(
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
res = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -409,11 +425,10 @@ def decode_one_batch(
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
res = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -424,56 +439,67 @@ def decode_one_batch(
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
hyp_tokens = greedy_search_batch(
|
||||
res = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
res = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
tokens = []
|
||||
timestamps = []
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
res = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
return_timestamps=True,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
res = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
return_timestamps=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
tokens.extend(res.tokens)
|
||||
timestamps.extend(res.timestamps)
|
||||
res = DecodingResults(tokens=tokens, timestamps=timestamps)
|
||||
|
||||
hyps, timestamps = parse_hyp_and_timestamp(
|
||||
decoding_method=params.decoding_method,
|
||||
res=res,
|
||||
sp=sp,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
frame_shift_ms=params.frame_shift_ms,
|
||||
word_table=word_table,
|
||||
)
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
return {"greedy_search": (hyps, timestamps)}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
@ -484,9 +510,9 @@ def decode_one_batch(
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
return {key: (hyps, timestamps)}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -496,7 +522,9 @@ def decode_dataset(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
) -> Dict[
|
||||
str, List[Tuple[str, List[str], List[str], List[float], List[float]]]
|
||||
]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
@ -517,9 +545,12 @@ def decode_dataset(
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
Its value is a list of tuples. Each tuple contains five elements:
|
||||
- cut_id
|
||||
- reference transcript
|
||||
- predicted result
|
||||
- timestamp of reference transcript
|
||||
- timestamp of predicted result
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
@ -538,6 +569,18 @@ def decode_dataset(
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
timestamps_ref = []
|
||||
for cut in batch["supervisions"]["cut"]:
|
||||
for s in cut.supervisions:
|
||||
time = []
|
||||
if s.alignment is not None and "word" in s.alignment:
|
||||
time = [
|
||||
aliword.start
|
||||
for aliword in s.alignment["word"]
|
||||
if aliword.symbol != ""
|
||||
]
|
||||
timestamps_ref.append(time)
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -547,12 +590,18 @@ def decode_dataset(
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
for name, (hyps, timestamps_hyp) in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
|
||||
timestamps_ref
|
||||
)
|
||||
for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
|
||||
cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
|
||||
):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
this_batch.append(
|
||||
(cut_id, ref_words, hyp_words, time_ref, time_hyp)
|
||||
)
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -570,15 +619,19 @@ def decode_dataset(
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
results_dict: Dict[
|
||||
str,
|
||||
List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
|
||||
],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
test_set_delays = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
store_transcripts_and_timestamps(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
@ -587,10 +640,11 @@ def save_results(
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
wer, mean_delay, var_delay = write_error_stats_with_timestamps(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
test_set_delays[key] = (mean_delay, var_delay)
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
@ -604,6 +658,19 @@ def save_results(
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
|
||||
delays_info = (
|
||||
params.res_dir
|
||||
/ f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(delays_info, "w") as f:
|
||||
print("settings\tsymbol-delay", file=f)
|
||||
for key, val in test_set_delays:
|
||||
print(
|
||||
"{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
|
||||
file=f,
|
||||
)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
@ -611,6 +678,15 @@ def save_results(
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
s = "\nFor {}, symbol-delay of different settings are:\n".format(
|
||||
test_set_name
|
||||
)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_delays:
|
||||
s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
|
@ -377,6 +377,7 @@ def get_params() -> AttributeDict:
|
||||
"""
|
||||
params = AttributeDict(
|
||||
{
|
||||
"frame_shift_ms": 10.0,
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
@ -25,7 +25,13 @@ from model import Transducer
|
||||
|
||||
from icefall import NgramLm, NgramLmStateCost
|
||||
from icefall.decode import Nbest, one_best_decoding
|
||||
from icefall.utils import add_eos, add_sos, get_texts
|
||||
from icefall.utils import (
|
||||
DecodingResults,
|
||||
add_eos,
|
||||
add_sos,
|
||||
get_texts,
|
||||
get_texts_with_timestamp,
|
||||
)
|
||||
|
||||
|
||||
def fast_beam_search_one_best(
|
||||
@ -37,7 +43,8 @@ def fast_beam_search_one_best(
|
||||
max_states: int,
|
||||
max_contexts: int,
|
||||
temperature: float = 1.0,
|
||||
) -> List[List[int]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
A lattice is first obtained using fast beam search, and then
|
||||
@ -61,8 +68,12 @@ def fast_beam_search_one_best(
|
||||
Max contexts per stream per frame.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -76,8 +87,11 @@ def fast_beam_search_one_best(
|
||||
)
|
||||
|
||||
best_path = one_best_decoding(lattice)
|
||||
hyps = get_texts(best_path)
|
||||
return hyps
|
||||
|
||||
if not return_timestamps:
|
||||
return get_texts(best_path)
|
||||
else:
|
||||
return get_texts_with_timestamp(best_path)
|
||||
|
||||
|
||||
def fast_beam_search_nbest_LG(
|
||||
@ -92,7 +106,8 @@ def fast_beam_search_nbest_LG(
|
||||
nbest_scale: float = 0.5,
|
||||
use_double_scores: bool = True,
|
||||
temperature: float = 1.0,
|
||||
) -> List[List[int]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
The process to get the results is:
|
||||
@ -129,8 +144,12 @@ def fast_beam_search_nbest_LG(
|
||||
single precision.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -195,9 +214,10 @@ def fast_beam_search_nbest_LG(
|
||||
best_hyp_indexes = ragged_tot_scores.argmax()
|
||||
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
|
||||
return hyps
|
||||
if not return_timestamps:
|
||||
return get_texts(best_path)
|
||||
else:
|
||||
return get_texts_with_timestamp(best_path)
|
||||
|
||||
|
||||
def fast_beam_search_nbest(
|
||||
@ -212,7 +232,8 @@ def fast_beam_search_nbest(
|
||||
nbest_scale: float = 0.5,
|
||||
use_double_scores: bool = True,
|
||||
temperature: float = 1.0,
|
||||
) -> List[List[int]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
The process to get the results is:
|
||||
@ -249,8 +270,12 @@ def fast_beam_search_nbest(
|
||||
single precision.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -279,9 +304,10 @@ def fast_beam_search_nbest(
|
||||
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
|
||||
return hyps
|
||||
if not return_timestamps:
|
||||
return get_texts(best_path)
|
||||
else:
|
||||
return get_texts_with_timestamp(best_path)
|
||||
|
||||
|
||||
def fast_beam_search_nbest_oracle(
|
||||
@ -297,7 +323,8 @@ def fast_beam_search_nbest_oracle(
|
||||
use_double_scores: bool = True,
|
||||
nbest_scale: float = 0.5,
|
||||
temperature: float = 1.0,
|
||||
) -> List[List[int]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
A lattice is first obtained using fast beam search, and then
|
||||
@ -338,8 +365,12 @@ def fast_beam_search_nbest_oracle(
|
||||
yields more unique paths.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -378,8 +409,10 @@ def fast_beam_search_nbest_oracle(
|
||||
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
return hyps
|
||||
if not return_timestamps:
|
||||
return get_texts(best_path)
|
||||
else:
|
||||
return get_texts_with_timestamp(best_path)
|
||||
|
||||
|
||||
def fast_beam_search(
|
||||
@ -469,8 +502,11 @@ def fast_beam_search(
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
|
||||
) -> List[int]:
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
max_sym_per_frame: int,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[int], DecodingResults]:
|
||||
"""Greedy search for a single utterance.
|
||||
Args:
|
||||
model:
|
||||
@ -480,8 +516,12 @@ def greedy_search(
|
||||
max_sym_per_frame:
|
||||
Maximum number of symbols per frame. If it is set to 0, the WER
|
||||
would be 100%.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
@ -507,6 +547,10 @@ def greedy_search(
|
||||
t = 0
|
||||
hyp = [blank_id] * context_size
|
||||
|
||||
# timestamp[i] is the frame index after subsampling
|
||||
# on which hyp[i] is decoded
|
||||
timestamp = []
|
||||
|
||||
# Maximum symbols per utterance.
|
||||
max_sym_per_utt = 1000
|
||||
|
||||
@ -533,6 +577,7 @@ def greedy_search(
|
||||
y = logits.argmax().item()
|
||||
if y not in (blank_id, unk_id):
|
||||
hyp.append(y)
|
||||
timestamp.append(t)
|
||||
decoder_input = torch.tensor(
|
||||
[hyp[-context_size:]], device=device
|
||||
).reshape(1, context_size)
|
||||
@@ -547,14 +592,21 @@
        t += 1
    hyp = hyp[context_size:]  # remove blanks

    if not return_timestamps:
        return hyp
    else:
        return DecodingResults(
            tokens=[hyp],
            timestamps=[timestamp],
        )

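With return_timestamps=True, the returned DecodingResults stores one subsampled frame index per emitted token. Converting those indices to seconds only needs the frame shift and the subsampling factor; a small sketch (10 ms frame shift and subsampling factor 4 are the usual recipe values, adjust to the model actually being decoded):

def frames_to_seconds(frame_indices, subsampling_factor=4, frame_shift_ms=10.0):
    # frame index after subsampling -> time in seconds
    return [t * subsampling_factor * frame_shift_ms / 1000.0 for t in frame_indices]


# tokens emitted on subsampled frames 12, 30 and 55:
print(frames_to_seconds([12, 30, 55]))  # [0.48, 1.2, 2.2]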
def greedy_search_batch(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
) -> List[List[int]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
@ -564,9 +616,12 @@ def greedy_search_batch(
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||
encoder_out before padding.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return a list-of-list of token IDs containing the decoded results.
|
||||
len(ans) equals to encoder_out.size(0).
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
@ -591,6 +646,10 @@ def greedy_search_batch(
|
||||
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
# timestamp[n][i] is the frame index after subsampling
|
||||
# on which hyp[n][i] is decoded
|
||||
timestamps = [[] for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
device=device,
|
||||
@ -604,7 +663,7 @@ def greedy_search_batch(
|
||||
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
for (t, batch_size) in enumerate(batch_size_list):
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = encoder_out.data[start:end]
|
||||
@ -626,6 +685,7 @@ def greedy_search_batch(
|
||||
for i, v in enumerate(y):
|
||||
if v not in (blank_id, unk_id):
|
||||
hyps[i].append(v)
|
||||
timestamps[i].append(t)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
@ -640,11 +700,19 @@ def greedy_search_batch(
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
ans_timestamps = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
ans_timestamps.append(timestamps[unsorted_indices[i]])
|
||||
|
||||
if not return_timestamps:
|
||||
return ans
|
||||
else:
|
||||
return DecodingResults(
|
||||
tokens=ans,
|
||||
timestamps=ans_timestamps,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -657,6 +725,10 @@ class Hypothesis:
|
||||
# It contains only one entry.
|
||||
log_prob: torch.Tensor
|
||||
|
||||
# timestamp[i] is the frame index after subsampling
|
||||
# on which ys[i] is decoded
|
||||
timestamp: List[int]
|
||||
|
||||
state_cost: Optional[NgramLmStateCost] = None
|
||||
|
||||
@property
|
||||
@ -806,7 +878,8 @@ def modified_beam_search(
|
||||
encoder_out_lens: torch.Tensor,
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
) -> List[List[int]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||
|
||||
Args:
|
||||
@ -821,9 +894,12 @@ def modified_beam_search(
|
||||
Number of active paths during the beam search.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return a list-of-list of token IDs. ans[i] is the decoding results
|
||||
for the i-th utterance.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
@ -851,6 +927,7 @@ def modified_beam_search(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
timestamp=[],
|
||||
)
|
||||
)
|
||||
|
||||
@ -858,7 +935,7 @@ def modified_beam_search(
|
||||
|
||||
offset = 0
|
||||
finalized_B = []
|
||||
for batch_size in batch_size_list:
|
||||
for (t, batch_size) in enumerate(batch_size_list):
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = encoder_out.data[start:end]
|
||||
@ -936,30 +1013,44 @@ def modified_beam_search(
|
||||
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[k]
|
||||
new_timestamp = hyp.timestamp[:]
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
new_timestamp.append(t)
|
||||
|
||||
new_log_prob = topk_log_probs[k]
|
||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||
new_hyp = Hypothesis(
|
||||
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
|
||||
)
|
||||
B[i].add(new_hyp)
|
||||
|
||||
B = B + finalized_B
|
||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||
|
||||
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||
sorted_timestamps = [h.timestamp for h in best_hyps]
|
||||
ans = []
|
||||
ans_timestamps = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
|
||||
|
||||
if not return_timestamps:
|
||||
return ans
|
||||
else:
|
||||
return DecodingResults(
|
||||
tokens=ans,
|
||||
timestamps=ans_timestamps,
|
||||
)
|
||||
|
||||
|
||||
def _deprecated_modified_beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
) -> List[int]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[int], DecodingResults]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
It decodes only one utterance at a time. We keep it only for reference.
|
||||
@ -974,8 +1065,13 @@ def _deprecated_modified_beam_search(
|
||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||
beam:
|
||||
Beam size.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
|
||||
assert encoder_out.ndim == 3
|
||||
@ -995,6 +1091,7 @@ def _deprecated_modified_beam_search(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
timestamp=[],
|
||||
)
|
||||
)
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
@ -1053,17 +1150,24 @@ def _deprecated_modified_beam_search(
|
||||
for i in range(len(topk_hyp_indexes)):
|
||||
hyp = A[topk_hyp_indexes[i]]
|
||||
new_ys = hyp.ys[:]
|
||||
new_timestamp = hyp.timestamp[:]
|
||||
new_token = topk_token_indexes[i]
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
new_timestamp.append(t)
|
||||
new_log_prob = topk_log_probs[i]
|
||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||
new_hyp = Hypothesis(
|
||||
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
|
||||
)
|
||||
B.add(new_hyp)
|
||||
|
||||
best_hyp = B.get_most_probable(length_norm=True)
|
||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||
|
||||
if not return_timestamps:
|
||||
return ys
|
||||
else:
|
||||
return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
|
||||
|
||||
|
||||
def beam_search(
|
||||
@ -1071,7 +1175,8 @@ def beam_search(
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
) -> List[int]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[int], DecodingResults]:
|
||||
"""
|
||||
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
|
||||
|
||||
@ -1086,8 +1191,13 @@ def beam_search(
|
||||
Beam size.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
@ -1114,7 +1224,7 @@ def beam_search(
|
||||
t = 0
|
||||
|
||||
B = HypothesisList()
|
||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
|
||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))
|
||||
|
||||
max_sym_per_utt = 20000
|
||||
|
||||
@ -1175,7 +1285,13 @@ def beam_search(
|
||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
||||
|
||||
# ys[:] returns a copy of ys
|
||||
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
||||
B.add(
|
||||
Hypothesis(
|
||||
ys=y_star.ys[:],
|
||||
log_prob=new_y_star_log_prob,
|
||||
timestamp=y_star.timestamp[:],
|
||||
)
|
||||
)
|
||||
|
||||
# Second, process other non-blank labels
|
||||
values, indices = log_prob.topk(beam + 1)
|
||||
@ -1184,7 +1300,14 @@ def beam_search(
|
||||
continue
|
||||
new_ys = y_star.ys + [i]
|
||||
new_log_prob = y_star.log_prob + v
|
||||
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
||||
new_timestamp = y_star.timestamp + [t]
|
||||
A.add(
|
||||
Hypothesis(
|
||||
ys=new_ys,
|
||||
log_prob=new_log_prob,
|
||||
timestamp=new_timestamp,
|
||||
)
|
||||
)
|
||||
|
||||
# Check whether B contains more than "beam" elements more probable
|
||||
# than the most probable in A
|
||||
@ -1200,7 +1323,11 @@ def beam_search(
|
||||
|
||||
best_hyp = B.get_most_probable(length_norm=True)
|
||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||
|
||||
if not return_timestamps:
|
||||
return ys
|
||||
else:
|
||||
return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
|
||||
|
||||
|
||||
def fast_beam_search_with_nbest_rescoring(
|
||||
@ -1220,7 +1347,8 @@ def fast_beam_search_with_nbest_rescoring(
|
||||
use_double_scores: bool = True,
|
||||
nbest_scale: float = 0.5,
|
||||
temperature: float = 1.0,
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Dict[str, Union[List[List[int]], DecodingResults]]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
A lattice is first obtained using fast beam search, num_path are selected
|
||||
and rescored using a given language model. The shortest path within the
|
||||
@ -1262,10 +1390,13 @@ def fast_beam_search_with_nbest_rescoring(
|
||||
yields more unique paths.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result in a dict, where the key has the form
|
||||
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
|
||||
ngram LM scale value used during decoding, i.e., 0.1.
|
||||
'ngram_lm_scale_xx' and the value is the decoded results
|
||||
optionally with timestamps. `xx` is the ngram LM scale value
|
||||
used during decoding, i.e., 0.1.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -1343,16 +1474,18 @@ def fast_beam_search_with_nbest_rescoring(
|
||||
log_semiring=False,
|
||||
)
|
||||
|
||||
ans: Dict[str, List[List[int]]] = {}
|
||||
ans: Dict[str, Union[List[List[int]], DecodingResults]] = {}
|
||||
for s in ngram_lm_scale_list:
|
||||
key = f"ngram_lm_scale_{s}"
|
||||
tot_scores = am_scores.values + s * ngram_lm_scores
|
||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||
max_indexes = ragged_tot_scores.argmax()
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
hyps = get_texts(best_path)
|
||||
|
||||
ans[key] = hyps
|
||||
if not return_timestamps:
|
||||
ans[key] = get_texts(best_path)
|
||||
else:
|
||||
ans[key] = get_texts_with_timestamp(best_path)
|
||||
|
||||
return ans
|
||||
|
||||
@ -1376,7 +1509,8 @@ def fast_beam_search_with_nbest_rnn_rescoring(
|
||||
use_double_scores: bool = True,
|
||||
nbest_scale: float = 0.5,
|
||||
temperature: float = 1.0,
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Dict[str, Union[List[List[int]], DecodingResults]]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
A lattice is first obtained using fast beam search, num_path are selected
|
||||
and rescored using a given language model and a rnn-lm.
|
||||
@ -1422,10 +1556,13 @@ def fast_beam_search_with_nbest_rnn_rescoring(
|
||||
yields more unique paths.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result in a dict, where the key has the form
|
||||
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
|
||||
ngram LM scale value used during decoding, i.e., 0.1.
|
||||
'ngram_lm_scale_xx' and the value is the decoded results
|
||||
optionally with timestamps. `xx` is the ngram LM scale value
|
||||
used during decoding, i.e., 0.1.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -1537,9 +1674,11 @@ def fast_beam_search_with_nbest_rnn_rescoring(
|
||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||
max_indexes = ragged_tot_scores.argmax()
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
hyps = get_texts(best_path)
|
||||
|
||||
ans[key] = hyps
|
||||
if not return_timestamps:
|
||||
ans[key] = get_texts(best_path)
|
||||
else:
|
||||
ans[key] = get_texts_with_timestamp(best_path)
|
||||
|
||||
return ans
|
||||
|
||||
|
@ -106,6 +106,22 @@ Usage:
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
To evaluate symbol delay, you should:
|
||||
(1) Generate cuts with word-time alignments:
|
||||
./local/add_alignment_librispeech.py \
|
||||
--alignments-dir data/alignment \
|
||||
--cuts-in-dir data/fbank \
|
||||
--cuts-out-dir data/fbank_ali
|
||||
(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
|
||||
For example:
|
||||
./pruned_transducer_stateless4/decode.py \
|
||||
--epoch 40 \
|
||||
--avg 20 \
|
||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--manifest-dir data/fbank_ali
|
||||
"""
|
||||
|
||||
|
||||
@ -142,10 +158,12 @@ from icefall.checkpoint import (
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
DecodingResults,
|
||||
parse_hyp_and_timestamp,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
store_transcripts_and_timestamps,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
write_error_stats_with_timestamps,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
@ -318,7 +336,7 @@ def get_parser():
|
||||
"--left-context",
|
||||
type=int,
|
||||
default=64,
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
help="left context can be seen during decoding (in frames after subsampling)", # noqa
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -350,7 +368,7 @@ def decode_one_batch(
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
@ -358,9 +376,10 @@ def decode_one_batch(
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
- value: It is a tuple. `len(value[0])` and `len(value[1])` are both
|
||||
equal to the batch size. `value[0][i]` and `value[1][i]`
|
||||
are the decoding result and timestamps for the i-th utterance
|
||||
in the given batch respectively.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
@ -379,8 +398,8 @@ def decode_one_batch(
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
Return the decoding result and timestamps. See above description for the
|
||||
format of the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
@ -412,10 +431,8 @@ def decode_one_batch(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
res = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -423,11 +440,10 @@ def decode_one_batch(
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
res = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -437,11 +453,10 @@ def decode_one_batch(
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
res = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -451,11 +466,10 @@ def decode_one_batch(
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
res = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -466,56 +480,67 @@ def decode_one_batch(
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
hyp_tokens = greedy_search_batch(
|
||||
res = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
res = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
tokens = []
|
||||
timestamps = []
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
res = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
return_timestamps=True,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
res = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
return_timestamps=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
tokens.extend(res.tokens)
|
||||
timestamps.extend(res.timestamps)
|
||||
res = DecodingResults(tokens=tokens, timestamps=timestamps)
|
||||
|
||||
hyps, timestamps = parse_hyp_and_timestamp(
|
||||
decoding_method=params.decoding_method,
|
||||
res=res,
|
||||
sp=sp,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
frame_shift_ms=params.frame_shift_ms,
|
||||
word_table=word_table,
|
||||
)
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
return {"greedy_search": (hyps, timestamps)}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
@ -526,9 +551,9 @@ def decode_one_batch(
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
return {key: (hyps, timestamps)}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -538,7 +563,9 @@ def decode_dataset(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
) -> Dict[
|
||||
str, List[Tuple[str, List[str], List[str], List[float], List[float]]]
|
||||
]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
@ -559,9 +586,12 @@ def decode_dataset(
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
Its value is a list of tuples. Each tuple contains five elements:
|
||||
- cut_id
|
||||
- reference transcript
|
||||
- predicted result
|
||||
- timestamp of reference transcript
|
||||
- timestamp of predicted result
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
@ -580,6 +610,18 @@ def decode_dataset(
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
timestamps_ref = []
|
||||
for cut in batch["supervisions"]["cut"]:
|
||||
for s in cut.supervisions:
|
||||
time = []
|
||||
if s.alignment is not None and "word" in s.alignment:
|
||||
time = [
|
||||
aliword.start
|
||||
for aliword in s.alignment["word"]
|
||||
if aliword.symbol != ""
|
||||
]
|
||||
timestamps_ref.append(time)
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -589,12 +631,18 @@ def decode_dataset(
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
for name, (hyps, timestamps_hyp) in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
|
||||
timestamps_ref
|
||||
)
|
||||
for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
|
||||
cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
|
||||
):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
this_batch.append(
|
||||
(cut_id, ref_words, hyp_words, time_ref, time_hyp)
|
||||
)
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -612,15 +660,19 @@ def decode_dataset(
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
results_dict: Dict[
|
||||
str,
|
||||
List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
|
||||
],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
test_set_delays = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
store_transcripts_and_timestamps(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
@ -629,10 +681,11 @@ def save_results(
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
wer, mean_delay, var_delay = write_error_stats_with_timestamps(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
test_set_delays[key] = (mean_delay, var_delay)
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
@ -646,6 +699,19 @@ def save_results(
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
|
||||
delays_info = (
|
||||
params.res_dir
|
||||
/ f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(delays_info, "w") as f:
|
||||
print("settings\tsymbol-delay", file=f)
|
||||
for key, val in test_set_delays:
|
||||
print(
|
||||
"{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
|
||||
file=f,
|
||||
)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
@ -653,6 +719,15 @@ def save_results(
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
s = "\nFor {}, symbol-delay of different settings are:\n".format(
|
||||
test_set_name
|
||||
)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_delays:
|
||||
s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
|
@ -386,6 +386,7 @@ def get_params() -> AttributeDict:
|
||||
"""
|
||||
params = AttributeDict(
|
||||
{
|
||||
"frame_shift_ms": 10.0,
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
|
446  icefall/utils.py
@@ -24,9 +24,10 @@ import re
import subprocess
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

import k2
import k2.version
@@ -248,6 +249,86 @@ def get_texts(
        return aux_labels.tolist()


@dataclass
class DecodingResults:
    # Decoded token IDs for each utterance in the batch
    tokens: List[List[int]]

    # timestamps[i][k] contains the frame number on which tokens[i][k]
    # is decoded
    timestamps: List[List[int]]

    # hyps[i] is the recognition result, i.e., word IDs
    # for the i-th utterance with fast_beam_search_nbest_LG.
    hyps: Union[List[List[int]], k2.RaggedTensor] = None


def get_tokens_and_timestamps(labels: List[int]) -> Tuple[List[int], List[int]]:
    tokens = []
    timestamps = []
    for i, v in enumerate(labels):
        if v != 0:
            tokens.append(v)
            timestamps.append(i)

    return tokens, timestamps

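A tiny illustration of the helper above: blank labels (ID 0) are dropped and every surviving token keeps the frame index it was emitted on.

labels = [0, 0, 25, 0, 7, 7, 0, 13]
tokens, timestamps = get_tokens_and_timestamps(labels)
print(tokens)      # [25, 7, 7, 13]
print(timestamps)  # [2, 4, 5, 7]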
def get_texts_with_timestamp(
|
||||
best_paths: k2.Fsa, return_ragged: bool = False
|
||||
) -> DecodingResults:
|
||||
"""Extract the texts (as word IDs) and timestamps from the best-path FSAs.
|
||||
Args:
|
||||
best_paths:
|
||||
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
|
||||
containing multiple FSAs, which is expected to be the result
|
||||
of k2.shortest_path (otherwise the returned values won't
|
||||
be meaningful).
|
||||
return_ragged:
|
||||
True to return a ragged tensor with two axes [utt][word_id].
|
||||
False to return a list-of-list word IDs.
|
||||
Returns:
|
||||
Returns a list of lists of int, containing the label sequences we
|
||||
decoded.
|
||||
"""
|
||||
if isinstance(best_paths.aux_labels, k2.RaggedTensor):
|
||||
# remove 0's and -1's.
|
||||
aux_labels = best_paths.aux_labels.remove_values_leq(0)
|
||||
# TODO: change arcs.shape() to arcs.shape
|
||||
aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
|
||||
|
||||
# remove the states and arcs axes.
|
||||
aux_shape = aux_shape.remove_axis(1)
|
||||
aux_shape = aux_shape.remove_axis(1)
|
||||
aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
|
||||
else:
|
||||
# remove axis corresponding to states.
|
||||
aux_shape = best_paths.arcs.shape().remove_axis(1)
|
||||
aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
|
||||
# remove 0's and -1's.
|
||||
aux_labels = aux_labels.remove_values_leq(0)
|
||||
|
||||
assert aux_labels.num_axes == 2
|
||||
|
||||
labels_shape = best_paths.arcs.shape().remove_axis(1)
|
||||
labels_list = k2.RaggedTensor(
|
||||
labels_shape, best_paths.labels.contiguous()
|
||||
).tolist()
|
||||
|
||||
tokens = []
|
||||
timestamps = []
|
||||
for labels in labels_list:
|
||||
token, time = get_tokens_and_timestamps(labels[:-1])
|
||||
tokens.append(token)
|
||||
timestamps.append(time)
|
||||
|
||||
return DecodingResults(
|
||||
tokens=tokens,
|
||||
timestamps=timestamps,
|
||||
hyps=aux_labels if return_ragged else aux_labels.tolist(),
|
||||
)
|
||||
|
||||
|
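A hedged sketch of how get_texts_with_timestamp is typically driven; `lattice` stands for a decoding lattice (an FsaVec) produced elsewhere, so this is illustrative rather than the exact call site in the recipes.

import k2

# `lattice` is assumed to be an FsaVec obtained during decoding.
best_paths = k2.shortest_path(lattice, use_double_scores=True)
res = get_texts_with_timestamp(best_paths)
# res.hyps[i]       -> word IDs of the i-th utterance
# res.tokens[i]     -> label IDs of the i-th utterance
# res.timestamps[i] -> frame index (after subsampling) of each label
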
def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
"""Extract labels or aux_labels from the best-path FSAs.

@@ -352,6 +433,33 @@ def store_transcripts(
print(f"{cut_id}:\thyp={hyp}", file=f)


def store_transcripts_and_timestamps(
filename: Pathlike,
texts: Iterable[Tuple[str, List[str], List[str], List[float], List[float]]],
) -> None:
"""Save predicted results and reference transcripts as well as their timestamps
to a file.

Args:
filename:
File to save the results to.
texts:
An iterable of tuples. The first element is the cur_id, the second is
the reference transcript and the third element is the predicted result.
Returns:
Return None.
"""
with open(filename, "w") as f:
for cut_id, ref, hyp, time_ref, time_hyp in texts:
print(f"{cut_id}:\tref={ref}", file=f)
print(f"{cut_id}:\thyp={hyp}", file=f)
if len(time_ref) > 0:
s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]"
print(f"{cut_id}:\ttimestamp_ref={s}", file=f)
s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)

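To make the on-disk format concrete, a small hedged example (utterance ID, words, times, and file name are all made up):

# One correctly recognized utterance with word-level start times in seconds.
results = [
    (
        "cut-0001",
        ["HELLO", "WORLD"],  # reference words
        ["HELLO", "WORLD"],  # hypothesis words
        [0.48, 0.93],  # reference word times
        [0.52, 1.01],  # hypothesis word times
    )
]
store_transcripts_and_timestamps("recogs-test.txt", results)
# The file then contains lines such as:
# cut-0001:	ref=['HELLO', 'WORLD']
# cut-0001:	hyp=['HELLO', 'WORLD']
# cut-0001:	timestamp_ref=[0.480, 0.930]
# cut-0001:	timestamp_hyp=[0.520, 1.010]
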
def write_error_stats(
f: TextIO,
test_set_name: str,
@@ -519,6 +627,211 @@ def write_error_stats(
return float(tot_err_rate)


def write_error_stats_with_timestamps(
f: TextIO,
test_set_name: str,
results: List[Tuple[str, List[str], List[str], List[float], List[float]]],
enable_log: bool = True,
) -> Tuple[float, float, float]:
"""Write statistics based on predicted results and reference transcripts
as well as their timestamps.

It will write the following to the given file:

- WER
- number of insertions, deletions, substitutions, corrects and total
reference words. For example::

Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
reference words (2337 correct)

- The difference between the reference transcript and predicted result.
An instance is given below::

THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

The above example shows that the reference word is `EDISON`,
but it is predicted to `ADDISON` (a substitution error).

Another example is::

FOR THE FIRST DAY (SIR->*) I THINK

The reference word `SIR` is missing in the predicted
results (a deletion error).
results:
An iterable of tuples. The first element is the cur_id, the second is
the reference transcript and the third element is the predicted result.
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.

Returns:
Return total word error rate and mean delay.
"""
subs: Dict[Tuple[str, str], int] = defaultdict(int)
ins: Dict[str, int] = defaultdict(int)
dels: Dict[str, int] = defaultdict(int)

# `words` stores counts per word, as follows:
# corr, ref_sub, hyp_sub, ins, dels
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
# Compute mean alignment delay on the correct words
all_delay = []
for cut_id, ref, hyp, time_ref, time_hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
has_time_ref = len(time_ref) > 0
if has_time_ref:
# pointer to timestamp_hyp
p_hyp = 0
# pointer to timestamp_ref
p_ref = 0
for ref_word, hyp_word in ali:
if ref_word == ERR:
ins[hyp_word] += 1
words[hyp_word][3] += 1
if has_time_ref:
p_hyp += 1
elif hyp_word == ERR:
dels[ref_word] += 1
words[ref_word][4] += 1
if has_time_ref:
p_ref += 1
elif hyp_word != ref_word:
subs[(ref_word, hyp_word)] += 1
words[ref_word][1] += 1
words[hyp_word][2] += 1
if has_time_ref:
p_hyp += 1
p_ref += 1
else:
words[ref_word][0] += 1
num_corr += 1
if has_time_ref:
all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
p_hyp += 1
p_ref += 1
if has_time_ref:
assert p_hyp == len(hyp), (p_hyp, len(hyp))
assert p_ref == len(ref), (p_ref, len(ref))

ref_len = sum([len(r) for _, r, _, _, _ in results])
sub_errs = sum(subs.values())
ins_errs = sum(ins.values())
del_errs = sum(dels.values())
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

mean_delay = "inf"
var_delay = "inf"
num_delay = len(all_delay)
if num_delay > 0:
mean_delay = sum(all_delay) / num_delay
var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
mean_delay = "%.3f" % mean_delay
var_delay = "%.3f" % var_delay

if enable_log:
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
logging.info(
f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} " # noqa
f"computed on {num_delay} correct words"
)

print(f"%WER = {tot_err_rate}", file=f)
print(
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
f"{sub_errs} substitutions, over {ref_len} reference "
f"words ({num_corr} correct)",
file=f,
)
print(
"Search below for sections starting with PER-UTT DETAILS:, "
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
file=f,
)

print("", file=f)
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
for cut_id, ref, hyp, _, _ in results:
ali = kaldialign.align(ref, hyp, ERR)
combine_successive_errors = True
if combine_successive_errors:
ali = [[[x], [y]] for x, y in ali]
for i in range(len(ali) - 1):
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
ali[i] = [[], []]
ali = [
[
list(filter(lambda a: a != ERR, x)),
list(filter(lambda a: a != ERR, y)),
]
for x, y in ali
]
ali = list(filter(lambda x: x != [[], []], ali))
ali = [
[
ERR if x == [] else " ".join(x),
ERR if y == [] else " ".join(y),
]
for x, y in ali
]

print(
f"{cut_id}:\t"
+ " ".join(
(
ref_word
if ref_word == hyp_word
else f"({ref_word}->{hyp_word})"
for ref_word, hyp_word in ali
)
),
file=f,
)

print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f)

for count, (ref, hyp) in sorted(
[(v, k) for k, v in subs.items()], reverse=True
):
print(f"{count} {ref} -> {hyp}", file=f)

print("", file=f)
print("DELETIONS: count ref", file=f)
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
print(f"{count} {ref}", file=f)

print("", file=f)
print("INSERTIONS: count hyp", file=f)
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
print(f"{count} {hyp}", file=f)

print("", file=f)
print(
"PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
)
for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
):
(corr, ref_sub, hyp_sub, ins, dels) = counts
tot_errs = ref_sub + hyp_sub + ins + dels
ref_count = corr + ref_sub + dels
hyp_count = corr + hyp_sub + ins

print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate), float(mean_delay), float(var_delay)

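The delay statistics returned above are simply the mean and (population) variance of the per-word differences time_hyp - time_ref collected on correctly recognized words. A quick stand-alone check with invented delay values:

# Hypothetical per-word delays in seconds.
all_delay = [0.10, 0.14, 0.06, 0.18]
num_delay = len(all_delay)
mean_delay = sum(all_delay) / num_delay  # 0.12
var_delay = sum((d - mean_delay) ** 2 for d in all_delay) / num_delay  # 0.002
print("%.3f" % mean_delay, "%.3f" % var_delay)  # 0.120 0.002
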
class MetricsTracker(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
@@ -978,6 +1291,137 @@ def display_and_save_batch(
logging.info(f"num tokens: {num_tokens}")


def convert_timestamp(
frames: List[int],
subsampling_factor: int,
frame_shift_ms: float = 10,
) -> List[float]:
"""Convert frame numbers to time (in seconds) given subsampling factor
and frame shift (in milliseconds).

Args:
frames:
A list of frame numbers after subsampling.
subsampling_factor:
The subsampling factor of the model.
frame_shift_ms:
Frame shift in milliseconds between two contiguous frames.
Return:
Return the time in seconds corresponding to each given frame.
"""
frame_shift = frame_shift_ms / 1000.0
time = []
for f in frames:
time.append(f * subsampling_factor * frame_shift)

return time

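A worked example of the conversion above, assuming a subsampling factor of 4 (common in these recipes) and the default 10 ms frame shift: frame index f maps to f * 4 * 0.01 seconds.

# Hypothetical frame indices after subsampling.
print(convert_timestamp([2, 5, 9], subsampling_factor=4, frame_shift_ms=10))
# -> approximately [0.08, 0.20, 0.36] (up to floating-point rounding)
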
def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
"""
Parse timestamp of each word.

Args:
tokens:
List of tokens.
timestamp:
List of timestamp of each token.

Returns:
List of timestamp of each word.
"""
start_token = b"\xe2\x96\x81".decode()  # '_'
assert len(tokens) == len(timestamp)
ans = []
for i in range(len(tokens)):
flag = False
if i == 0 or tokens[i].startswith(start_token):
flag = True
if len(tokens[i]) == 1 and tokens[i].startswith(start_token):
# tokens[i] == start_token
if i == len(tokens) - 1:
# it is the last token
flag = False
elif tokens[i + 1].startswith(start_token):
# the next token also starts with start_token
flag = False
if flag:
ans.append(timestamp[i])
return ans

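A hedged, made-up example of the word-boundary logic above: only BPE pieces that start a new word (those beginning with the '▁' marker, plus the very first piece) keep their timestamp, so the output has one entry per word.

tokens = ["▁HE", "LLO", "▁WOR", "LD"]
times = [0.08, 0.16, 0.40, 0.52]  # hypothetical per-token times in seconds
print(parse_timestamp(tokens, times))
# -> [0.08, 0.4], i.e. one timestamp each for HELLO and WORLD
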
def parse_hyp_and_timestamp(
res: DecodingResults,
decoding_method: str,
sp: spm.SentencePieceProcessor,
subsampling_factor: int,
frame_shift_ms: float = 10,
word_table: Optional[k2.SymbolTable] = None,
) -> Tuple[List[List[str]], List[List[float]]]:
"""Parse hypothesis and timestamp.

Args:
res:
A DecodingResults object.
decoding_method:
Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
sp:
The BPE model.
subsampling_factor:
The integer subsampling factor.
frame_shift_ms:
The float frame shift used for feature extraction.
word_table:
The word symbol table.

Returns:
Return a list of hypothesis and timestamp.
"""
assert decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)

hyps = []
timestamps = []

N = len(res.tokens)
assert len(res.timestamps) == N
use_word_table = False
if decoding_method == "fast_beam_search_nbest_LG":
assert word_table is not None
use_word_table = True

for i in range(N):
tokens = sp.id_to_piece(res.tokens[i])
if use_word_table:
words = [word_table[i] for i in res.hyps[i]]
else:
words = sp.decode_pieces(tokens).split()
time = convert_timestamp(
res.timestamps[i], subsampling_factor, frame_shift_ms
)
time = parse_timestamp(tokens, time)
assert len(time) == len(words), (tokens, words)

hyps.append(words)
timestamps.append(time)

return hyps, timestamps

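Finally, a hedged sketch of how a decode script might call this helper; `res` is assumed to be a DecodingResults returned by one of the search routines, and the BPE model path is hypothetical.

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # hypothetical path

hyps, timestamps = parse_hyp_and_timestamp(
    res=res,  # assumed DecodingResults from, e.g., greedy search
    decoding_method="greedy_search",
    sp=sp,
    subsampling_factor=4,  # assumed model subsampling factor
    frame_shift_ms=10,
)
# hyps[i] holds the words and timestamps[i] the per-word start times
# (in seconds) for the i-th utterance in the batch.
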
# `is_module_available` is copied from
# https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9
def is_module_available(*modules: str) -> bool: