Get timestamps during decoding (#598)

* print out timestamps during decoding

* add word-level alignments

* support computing mean symbol delay with word-level alignments

* print variance of symbol delay

* update doc

* support computing delay for pruned_transducer_stateless4

* fix bug

* add doc
Zengwei Yao 2022-11-01 10:24:00 +08:00 committed by GitHub
parent ff3f026381
commit 03668771d7
8 changed files with 1094 additions and 150 deletions
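The symbol-delay statistics added by this PR compare decoded timestamps against reference word-alignment start times. A minimal sketch of the idea (not the icefall implementation; the real code pairs words through the edit-distance alignment inside write_error_stats_with_timestamps):

from typing import List, Tuple


def symbol_delay_stats(
    timestamps_hyp: List[List[float]],  # decoded word start times (seconds)
    timestamps_ref: List[List[float]],  # word start times from alignments
) -> Tuple[float, float]:
    # Collect per-word delays over the whole test set, then report
    # their mean and variance, as the summary files below do.
    delays = []
    for hyp, ref in zip(timestamps_hyp, timestamps_ref):
        # Naive pairing by position; the real code pairs words
        # via the WER alignment instead.
        for h, r in zip(hyp, ref):
            delays.append(h - r)
    mean = sum(delays) / len(delays)
    var = sum((d - mean) ** 2 for d in delays) / len(delays)
    return mean, var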

View File

@@ -0,0 +1,12 @@
#!/usr/bin/env bash
set -eou pipefail
alignments_dir=data/alignment
cuts_in_dir=data/fbank
cuts_out_dir=data/fbank_ali
python3 ./local/add_alignment_librispeech.py \
--alignments-dir $alignments_dir \
--cuts-in-dir $cuts_in_dir \
--cuts-out-dir $cuts_out_dir

View File

@@ -0,0 +1,196 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file adds alignments from https://github.com/CorentinJ/librispeech-alignments # noqa
to the existing fbank features dir (e.g., data/fbank)
and save cuts to a new dir (e.g., data/fbank_ali).
"""
import argparse
import logging
import zipfile
from pathlib import Path
from typing import List
from lhotse import CutSet, load_manifest_lazy
from lhotse.recipes.librispeech import parse_alignments
from lhotse.utils import is_module_available
LIBRISPEECH_ALIGNMENTS_URL = (
"https://drive.google.com/uc?id=1WYfgr31T-PPwMcxuAq09XZfHQO5Mw8fE"
)
DATASET_PARTS = [
"dev-clean",
"dev-other",
"test-clean",
"test-other",
"train-clean-100",
"train-clean-360",
"train-other-500",
]
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--alignments-dir",
type=str,
default="data/alignment",
help="The dir to save alignments.",
)
parser.add_argument(
"--cuts-in-dir",
type=str,
default="data/fbank",
help="The dir of the existing cuts without alignments.",
)
parser.add_argument(
"--cuts-out-dir",
type=str,
default="data/fbank_ali",
help="The dir to save the new cuts with alignments",
)
return parser
def download_alignments(
target_dir: str, alignments_url: str = LIBRISPEECH_ALIGNMENTS_URL
):
"""
Download and extract the alignments.
Note: If you can not access drive.google.com, you could download the file
`LibriSpeech-Alignments.zip` from huggingface:
https://huggingface.co/Zengwei/librispeech-alignments
and extract the zip file manually.
Args:
target_dir:
The dir to save alignments.
alignments_url:
The URL of alignments.
"""
"""Modified from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/librispeech.py""" # noqa
target_dir = Path(target_dir)
target_dir.mkdir(parents=True, exist_ok=True)
completed_detector = target_dir / ".ali_completed"
if completed_detector.is_file():
logging.info("The alignment files already exist.")
return
ali_zip_path = target_dir / "LibriSpeech-Alignments.zip"
if not ali_zip_path.is_file():
assert is_module_available(
"gdown"
), 'To download LibriSpeech alignments, please install "pip install gdown"' # noqa
import gdown
gdown.download(alignments_url, output=str(ali_zip_path))
with zipfile.ZipFile(str(ali_zip_path)) as f:
f.extractall(path=target_dir)
completed_detector.touch()
def add_alignment(
alignments_dir: str,
cuts_in_dir: str = "data/fbank",
cuts_out_dir: str = "data/fbank_ali",
dataset_parts: List[str] = DATASET_PARTS,
):
"""
Add alignment info to existing cuts.
Args:
alignments_dir:
The dir of the alignments.
cuts_in_dir:
The dir of the existing cuts.
cuts_out_dir:
The dir to save the new cuts with alignments.
dataset_parts:
Librispeech parts to add alignments.
"""
alignments_dir = Path(alignments_dir)
cuts_in_dir = Path(cuts_in_dir)
cuts_out_dir = Path(cuts_out_dir)
cuts_out_dir.mkdir(parents=True, exist_ok=True)
for part in dataset_parts:
logging.info(f"Processing {part}")
cuts_in_path = cuts_in_dir / f"librispeech_cuts_{part}.jsonl.gz"
if not cuts_in_path.is_file():
logging.info(f"{cuts_in_path} does not exist - skipping.")
continue
cuts_out_path = cuts_out_dir / f"librispeech_cuts_{part}.jsonl.gz"
if cuts_out_path.is_file():
logging.info(f"{part} already exists - skipping.")
continue
# parse alignments
alignments = {}
part_ali_dir = alignments_dir / "LibriSpeech" / part
for ali_path in part_ali_dir.rglob("*.alignment.txt"):
ali = parse_alignments(ali_path)
alignments.update(ali)
logging.info(
f"{part} has {len(alignments.keys())} cuts with alignments."
)
# add alignment attribute and write out
cuts_in = load_manifest_lazy(cuts_in_path)
with CutSet.open_writer(cuts_out_path) as writer:
for cut in cuts_in:
for idx, subcut in enumerate(cut.supervisions):
origin_id = subcut.id.split("_")[0]
if origin_id in alignments:
ali = alignments[origin_id]
else:
logging.info(
f"Warning: {origin_id} does not has alignment."
)
ali = []
subcut.alignment = {"word": ali}
writer.write(cut, flush=True)
def main():
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
download_alignments(args.alignments_dir)
add_alignment(args.alignments_dir, args.cuts_in_dir, args.cuts_out_dir)
if __name__ == "__main__":
main()
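To sanity-check the output of this script, one can load a cut set from the new dir and print the word alignment of one supervision. A quick sketch (the manifest name follows the pattern used above; the `.symbol`/`.start` fields are the ones the decode scripts below read):

from lhotse import load_manifest_lazy

cuts = load_manifest_lazy(
    "data/fbank_ali/librispeech_cuts_test-clean.jsonl.gz"
)
cut = next(iter(cuts))
for item in cut.supervisions[0].alignment["word"]:
    if item.symbol != "":  # skip silence entries
        print(f"{item.start:6.2f}s  {item.symbol}")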

View File

@@ -91,6 +91,22 @@ Usage:
     --beam 20.0 \
     --max-contexts 8 \
     --max-states 64
+
+To evaluate symbol delay, you should:
+(1) Generate cuts with word-time alignments:
+./local/add_alignment_librispeech.py \
+    --alignments-dir data/alignment \
+    --cuts-in-dir data/fbank \
+    --cuts-out-dir data/fbank_ali
+(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
+For example:
+./lstm_transducer_stateless3/decode.py \
+    --epoch 40 \
+    --avg 20 \
+    --exp-dir ./lstm_transducer_stateless3/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search \
+    --manifest-dir data/fbank_ali
 """
@@ -127,10 +143,12 @@ from icefall.checkpoint import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    DecodingResults,
+    parse_hyp_and_timestamp,
     setup_logger,
-    store_transcripts,
+    store_transcripts_and_timestamps,
     str2bool,
-    write_error_stats,
+    write_error_stats_with_timestamps,
 )

 LOG_EPS = math.log(1e-10)
@@ -314,7 +332,7 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[str]]]:
+) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -322,9 +340,11 @@ def decode_one_batch(
         if greedy_search is used, it would be "greedy_search"
         If beam search with a beam size of 7 is used, it would be
         "beam_7"
-      - value: It contains the decoding result. `len(value)` equals to
-        batch size. `value[i]` is the decoding result for the i-th
-        utterance in the given batch.
+      - value: It is a tuple. `len(value[0])` and `len(value[1])` are both
+        equal to the batch size. `value[0][i]` and `value[1][i]`
+        are the decoding result and timestamps for the i-th utterance
+        in the given batch respectively.
+
     Args:
       params:
         It's the return value of :func:`get_params`.
@@ -343,8 +363,8 @@ def decode_one_batch(
        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
-      Return the decoding result. See above description for the format of
-      the returned dict.
+      Return the decoding result and timestamps. See above description for the
+      format of the returned dict.
     """
     device = next(model.parameters()).device
     feature = batch["inputs"]
@@ -370,10 +390,8 @@ def decode_one_batch(
         x=feature, x_lens=feature_lens
     )

-    hyps = []
-
     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search_one_best(
+        res = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -381,11 +399,10 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_nbest_LG":
-        hyp_tokens = fast_beam_search_nbest_LG(
+        res = fast_beam_search_nbest_LG(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -395,11 +412,10 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            return_timestamps=True,
         )
-        for hyp in hyp_tokens:
-            hyps.append([word_table[i] for i in hyp])
     elif params.decoding_method == "fast_beam_search_nbest":
-        hyp_tokens = fast_beam_search_nbest(
+        res = fast_beam_search_nbest(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -409,11 +425,10 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_nbest_oracle":
-        hyp_tokens = fast_beam_search_nbest_oracle(
+        res = fast_beam_search_nbest_oracle(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -424,56 +439,67 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=sp.encode(supervisions["text"]),
             nbest_scale=params.nbest_scale,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
     ):
-        hyp_tokens = greedy_search_batch(
+        res = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search":
-        hyp_tokens = modified_beam_search(
+        res = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
+        tokens = []
+        timestamps = []
         for i in range(batch_size):
             # fmt: off
             encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
             # fmt: on
             if params.decoding_method == "greedy_search":
-                hyp = greedy_search(
+                res = greedy_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     max_sym_per_frame=params.max_sym_per_frame,
+                    return_timestamps=True,
                 )
             elif params.decoding_method == "beam_search":
-                hyp = beam_search(
+                res = beam_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
+                    return_timestamps=True,
                 )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            tokens.extend(res.tokens)
+            timestamps.extend(res.timestamps)
+        res = DecodingResults(tokens=tokens, timestamps=timestamps)
+
+    hyps, timestamps = parse_hyp_and_timestamp(
+        decoding_method=params.decoding_method,
+        res=res,
+        sp=sp,
+        subsampling_factor=params.subsampling_factor,
+        frame_shift_ms=params.frame_shift_ms,
+        word_table=word_table,
+    )

     if params.decoding_method == "greedy_search":
-        return {"greedy_search": hyps}
+        return {"greedy_search": (hyps, timestamps)}
     elif "fast_beam_search" in params.decoding_method:
         key = f"beam_{params.beam}_"
         key += f"max_contexts_{params.max_contexts}_"
@@ -484,9 +510,9 @@ def decode_one_batch(
         if "LG" in params.decoding_method:
             key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
-        return {key: hyps}
+        return {key: (hyps, timestamps)}
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": (hyps, timestamps)}


 def decode_dataset(
@@ -496,7 +522,9 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+) -> Dict[
+    str, List[Tuple[str, List[str], List[str], List[float], List[float]]]
+]:
     """Decode dataset.

     Args:
@@ -517,9 +545,12 @@ def decode_dataset(
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
-      Its value is a list of tuples. Each tuple contains two elements:
-      The first is the reference transcript, and the second is the
-      predicted result.
+      Its value is a list of tuples. Each tuple contains five elements:
+       - cut_id
+       - reference transcript
+       - predicted result
+       - timestamp of reference transcript
+       - timestamp of predicted result
     """
     num_cuts = 0
@@ -538,6 +569,18 @@ def decode_dataset(
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

+        timestamps_ref = []
+        for cut in batch["supervisions"]["cut"]:
+            for s in cut.supervisions:
+                time = []
+                if s.alignment is not None and "word" in s.alignment:
+                    time = [
+                        aliword.start
+                        for aliword in s.alignment["word"]
+                        if aliword.symbol != ""
+                    ]
+                timestamps_ref.append(time)
+
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
@@ -547,12 +590,18 @@ def decode_dataset(
             batch=batch,
         )

-        for name, hyps in hyps_dict.items():
+        for name, (hyps, timestamps_hyp) in hyps_dict.items():
             this_batch = []
-            assert len(hyps) == len(texts)
-            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+            assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
+                timestamps_ref
+            )
+            for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
+                cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
+            ):
                 ref_words = ref_text.split()
-                this_batch.append((cut_id, ref_words, hyp_words))
+                this_batch.append(
+                    (cut_id, ref_words, hyp_words, time_ref, time_hyp)
+                )

             results[name].extend(this_batch)
@@ -570,15 +619,19 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    results_dict: Dict[
+        str,
+        List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
+    ],
 ):
     test_set_wers = dict()
+    test_set_delays = dict()
     for key, results in results_dict.items():
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        store_transcripts_and_timestamps(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
@@ -587,10 +640,11 @@ def save_results(
             params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
-            wer = write_error_stats(
+            wer, mean_delay, var_delay = write_error_stats_with_timestamps(
                 f, f"{test_set_name}-{key}", results, enable_log=True
             )
             test_set_wers[key] = wer
+            test_set_delays[key] = (mean_delay, var_delay)

         logging.info("Wrote detailed error stats to {}".format(errs_filename))
@@ -604,6 +658,19 @@ def save_results(
         for key, val in test_set_wers:
             print("{}\t{}".format(key, val), file=f)

+    test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
+    delays_info = (
+        params.res_dir
+        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(delays_info, "w") as f:
+        print("settings\tsymbol-delay", file=f)
+        for key, val in test_set_delays:
+            print(
+                "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
+                file=f,
+            )
+
     s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key, val in test_set_wers:
@@ -611,6 +678,15 @@ def save_results(
         note = ""
     logging.info(s)

+    s = "\nFor {}, symbol-delay of different settings are:\n".format(
+        test_set_name
+    )
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_delays:
+        s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
+        note = ""
+    logging.info(s)
+

 @torch.no_grad()
 def main():

View File

@@ -377,6 +377,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
+            "frame_shift_ms": 10.0,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,

View File

@@ -16,7 +16,7 @@
 import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

 import k2
 import sentencepiece as spm
@@ -25,7 +25,13 @@ from model import Transducer
 from icefall import NgramLm, NgramLmStateCost
 from icefall.decode import Nbest, one_best_decoding
-from icefall.utils import add_eos, add_sos, get_texts
+from icefall.utils import (
+    DecodingResults,
+    add_eos,
+    add_sos,
+    get_texts,
+    get_texts_with_timestamp,
+)


 def fast_beam_search_one_best(
@@ -37,7 +43,8 @@ def fast_beam_search_one_best(
     max_states: int,
     max_contexts: int,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.

     A lattice is first obtained using fast beam search, and then
@@ -61,8 +68,12 @@ def fast_beam_search_one_best(
        Max contexts pre stream per frame.
      temperature:
        Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,
@@ -76,8 +87,11 @@ def fast_beam_search_one_best(
     )

     best_path = one_best_decoding(lattice)
-    hyps = get_texts(best_path)

-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)


 def fast_beam_search_nbest_LG(
@@ -92,7 +106,8 @@ def fast_beam_search_nbest_LG(
     nbest_scale: float = 0.5,
     use_double_scores: bool = True,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.

     The process to get the results is:
@@ -129,8 +144,12 @@ def fast_beam_search_nbest_LG(
        single precision.
      temperature:
        Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,
@@ -195,9 +214,10 @@ def fast_beam_search_nbest_LG(
     best_hyp_indexes = ragged_tot_scores.argmax()
     best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)

-    hyps = get_texts(best_path)
-
-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)


 def fast_beam_search_nbest(
@@ -212,7 +232,8 @@ def fast_beam_search_nbest(
     nbest_scale: float = 0.5,
     use_double_scores: bool = True,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.

     The process to get the results is:
@@ -249,8 +270,12 @@ def fast_beam_search_nbest(
        single precision.
      temperature:
        Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,
@@ -279,9 +304,10 @@ def fast_beam_search_nbest(
     best_path = k2.index_fsa(nbest.fsa, max_indexes)

-    hyps = get_texts(best_path)
-
-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)


 def fast_beam_search_nbest_oracle(
@@ -297,7 +323,8 @@ def fast_beam_search_nbest_oracle(
     use_double_scores: bool = True,
     nbest_scale: float = 0.5,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.

     A lattice is first obtained using fast beam search, and then
@@ -338,8 +365,12 @@ def fast_beam_search_nbest_oracle(
        yields more unique paths.
      temperature:
        Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,
@@ -378,8 +409,10 @@ def fast_beam_search_nbest_oracle(
     best_path = k2.index_fsa(nbest.fsa, max_indexes)

-    hyps = get_texts(best_path)
-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)


 def fast_beam_search(
@@ -469,8 +502,11 @@ def fast_beam_search(
 def greedy_search(
-    model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
-) -> List[int]:
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    max_sym_per_frame: int,
+    return_timestamps: bool = False,
+) -> Union[List[int], DecodingResults]:
     """Greedy search for a single utterance.

     Args:
       model:
@@ -480,8 +516,12 @@ def greedy_search(
      max_sym_per_frame:
        Maximum number of symbols per frame. If it is set to 0, the WER
        would be 100%.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     assert encoder_out.ndim == 3

@@ -507,6 +547,10 @@ def greedy_search(
     t = 0
     hyp = [blank_id] * context_size

+    # timestamp[i] is the frame index after subsampling
+    # on which hyp[i] is decoded
+    timestamp = []
+
     # Maximum symbols per utterance.
     max_sym_per_utt = 1000

@@ -533,6 +577,7 @@ def greedy_search(
         y = logits.argmax().item()
         if y not in (blank_id, unk_id):
             hyp.append(y)
+            timestamp.append(t)
             decoder_input = torch.tensor(
                 [hyp[-context_size:]], device=device
             ).reshape(1, context_size)
@@ -547,14 +592,21 @@ def greedy_search(
             t += 1
     hyp = hyp[context_size:]  # remove blanks

-    return hyp
+    if not return_timestamps:
+        return hyp
+    else:
+        return DecodingResults(
+            tokens=[hyp],
+            timestamps=[timestamp],
+        )


 def greedy_search_batch(
     model: Transducer,
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

     Args:
       model:
@@ -564,9 +616,12 @@ def greedy_search_batch(
      encoder_out_lens:
        A 1-D tensor of shape (N,), containing number of valid frames in
        encoder_out before padding.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return a list-of-list of token IDs containing the decoded results.
-      len(ans) equals to encoder_out.size(0).
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     assert encoder_out.ndim == 3
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
@@ -591,6 +646,10 @@ def greedy_search_batch(

     hyps = [[blank_id] * context_size for _ in range(N)]

+    # timestamp[n][i] is the frame index after subsampling
+    # on which hyp[n][i] is decoded
+    timestamps = [[] for _ in range(N)]
+
     decoder_input = torch.tensor(
         hyps,
         device=device,
@@ -604,7 +663,7 @@ def greedy_search_batch(
     encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

     offset = 0
-    for batch_size in batch_size_list:
+    for (t, batch_size) in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]
@@ -626,6 +685,7 @@ def greedy_search_batch(
         for i, v in enumerate(y):
             if v not in (blank_id, unk_id):
                 hyps[i].append(v)
+                timestamps[i].append(t)
                 emitted = True
         if emitted:
             # update decoder output
@@ -640,11 +700,19 @@ def greedy_search_batch(

     sorted_ans = [h[context_size:] for h in hyps]
     ans = []
+    ans_timestamps = []
     unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])
+        ans_timestamps.append(timestamps[unsorted_indices[i]])

-    return ans
+    if not return_timestamps:
+        return ans
+    else:
+        return DecodingResults(
+            tokens=ans,
+            timestamps=ans_timestamps,
+        )


 @dataclass
@@ -657,6 +725,10 @@ class Hypothesis:
     # It contains only one entry.
     log_prob: torch.Tensor

+    # timestamp[i] is the frame index after subsampling
+    # on which ys[i] is decoded
+    timestamp: List[int]
+
     state_cost: Optional[NgramLmStateCost] = None

     @property
@@ -806,7 +878,8 @@ def modified_beam_search(
     encoder_out_lens: torch.Tensor,
     beam: int = 4,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

     Args:
@@ -821,9 +894,12 @@ def modified_beam_search(
        Number of active paths during the beam search.
      temperature:
        Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return a list-of-list of token IDs. ans[i] is the decoding results
-      for the i-th utterance.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     assert encoder_out.ndim == 3, encoder_out.shape
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
@@ -851,6 +927,7 @@ def modified_beam_search(
             Hypothesis(
                 ys=[blank_id] * context_size,
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                timestamp=[],
             )
         )
@@ -858,7 +935,7 @@ def modified_beam_search(
     offset = 0
     finalized_B = []
-    for batch_size in batch_size_list:
+    for (t, batch_size) in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]
@@ -936,30 +1013,44 @@ def modified_beam_search(
                 new_ys = hyp.ys[:]
                 new_token = topk_token_indexes[k]
+                new_timestamp = hyp.timestamp[:]
                 if new_token not in (blank_id, unk_id):
                     new_ys.append(new_token)
+                    new_timestamp.append(t)

                 new_log_prob = topk_log_probs[k]
-                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+                new_hyp = Hypothesis(
+                    ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
+                )
                 B[i].add(new_hyp)

     B = B + finalized_B
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]

     sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    sorted_timestamps = [h.timestamp for h in best_hyps]
     ans = []
+    ans_timestamps = []
     unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])
+        ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])

-    return ans
+    if not return_timestamps:
+        return ans
+    else:
+        return DecodingResults(
+            tokens=ans,
+            timestamps=ans_timestamps,
+        )


 def _deprecated_modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
-) -> List[int]:
+    return_timestamps: bool = False,
+) -> Union[List[int], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.

     It decodes only one utterance at a time. We keep it only for reference.
@@ -974,8 +1065,13 @@ def _deprecated_modified_beam_search(
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
      beam:
        Beam size.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """

     assert encoder_out.ndim == 3
@@ -995,6 +1091,7 @@ def _deprecated_modified_beam_search(
         Hypothesis(
             ys=[blank_id] * context_size,
             log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+            timestamp=[],
         )
     )

     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -1053,17 +1150,24 @@ def _deprecated_modified_beam_search(
         for i in range(len(topk_hyp_indexes)):
             hyp = A[topk_hyp_indexes[i]]
             new_ys = hyp.ys[:]
+            new_timestamp = hyp.timestamp[:]
             new_token = topk_token_indexes[i]
             if new_token not in (blank_id, unk_id):
                 new_ys.append(new_token)
+                new_timestamp.append(t)

             new_log_prob = topk_log_probs[i]
-            new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+            new_hyp = Hypothesis(
+                ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
+            )
             B.add(new_hyp)

     best_hyp = B.get_most_probable(length_norm=True)
     ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks

-    return ys
+    if not return_timestamps:
+        return ys
+    else:
+        return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])


 def beam_search(
@@ -1071,7 +1175,8 @@ def beam_search(
     encoder_out: torch.Tensor,
     beam: int = 4,
     temperature: float = 1.0,
-) -> List[int]:
+    return_timestamps: bool = False,
+) -> Union[List[int], DecodingResults]:
     """
     It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf

@@ -1086,8 +1191,13 @@ def beam_search(
        Beam size.
      temperature:
        Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     assert encoder_out.ndim == 3

@@ -1114,7 +1224,7 @@ def beam_search(
     t = 0

     B = HypothesisList()
-    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
+    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))

     max_sym_per_utt = 20000

@@ -1175,7 +1285,13 @@ def beam_search(
             new_y_star_log_prob = y_star.log_prob + skip_log_prob

             # ys[:] returns a copy of ys
-            B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
+            B.add(
+                Hypothesis(
+                    ys=y_star.ys[:],
+                    log_prob=new_y_star_log_prob,
+                    timestamp=y_star.timestamp[:],
+                )
+            )

             # Second, process other non-blank labels
             values, indices = log_prob.topk(beam + 1)
@@ -1184,7 +1300,14 @@ def beam_search(
                     continue
                 new_ys = y_star.ys + [i]
                 new_log_prob = y_star.log_prob + v
-                A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
+                new_timestamp = y_star.timestamp + [t]
+                A.add(
+                    Hypothesis(
+                        ys=new_ys,
+                        log_prob=new_log_prob,
+                        timestamp=new_timestamp,
+                    )
+                )

             # Check whether B contains more than "beam" elements more probable
             # than the most probable in A
@@ -1200,7 +1323,11 @@ def beam_search(
     best_hyp = B.get_most_probable(length_norm=True)
     ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks

-    return ys
+    if not return_timestamps:
+        return ys
+    else:
+        return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])


 def fast_beam_search_with_nbest_rescoring(
@@ -1220,7 +1347,8 @@ def fast_beam_search_with_nbest_rescoring(
     use_double_scores: bool = True,
     nbest_scale: float = 0.5,
     temperature: float = 1.0,
-) -> Dict[str, List[List[int]]]:
+    return_timestamps: bool = False,
+) -> Dict[str, Union[List[List[int]], DecodingResults]]:
     """It limits the maximum number of symbols per frame to 1.
     A lattice is first obtained using fast beam search, num_path are selected
     and rescored using a given language model. The shortest path within the
@@ -1262,10 +1390,13 @@ def fast_beam_search_with_nbest_rescoring(
        yields more unique paths.
      temperature:
        Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
       Return the decoded result in a dict, where the key has the form
-      'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
-      ngram LM scale value used during decoding, i.e., 0.1.
+      'ngram_lm_scale_xx' and the value is the decoded results
+      optionally with timestamps. `xx` is the ngram LM scale value
+      used during decoding, i.e., 0.1.
     """
     lattice = fast_beam_search(
         model=model,
@@ -1343,16 +1474,18 @@ def fast_beam_search_with_nbest_rescoring(
         log_semiring=False,
     )

-    ans: Dict[str, List[List[int]]] = {}
+    ans: Dict[str, Union[List[List[int]], DecodingResults]] = {}
     for s in ngram_lm_scale_list:
         key = f"ngram_lm_scale_{s}"
         tot_scores = am_scores.values + s * ngram_lm_scores
         ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
         max_indexes = ragged_tot_scores.argmax()
         best_path = k2.index_fsa(nbest.fsa, max_indexes)
-        hyps = get_texts(best_path)
-        ans[key] = hyps
+
+        if not return_timestamps:
+            ans[key] = get_texts(best_path)
+        else:
+            ans[key] = get_texts_with_timestamp(best_path)

     return ans
@@ -1376,7 +1509,8 @@ def fast_beam_search_with_nbest_rnn_rescoring(
     use_double_scores: bool = True,
     nbest_scale: float = 0.5,
     temperature: float = 1.0,
-) -> Dict[str, List[List[int]]]:
+    return_timestamps: bool = False,
+) -> Dict[str, Union[List[List[int]], DecodingResults]]:
     """It limits the maximum number of symbols per frame to 1.
     A lattice is first obtained using fast beam search, num_path are selected
     and rescored using a given language model and a rnn-lm.
@@ -1422,10 +1556,13 @@ def fast_beam_search_with_nbest_rnn_rescoring(
        yields more unique paths.
      temperature:
        Softmax temperature.
+      return_timestamps:
+        Whether to return timestamps.
     Returns:
       Return the decoded result in a dict, where the key has the form
-      'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
-      ngram LM scale value used during decoding, i.e., 0.1.
+      'ngram_lm_scale_xx' and the value is the decoded results
+      optionally with timestamps. `xx` is the ngram LM scale value
+      used during decoding, i.e., 0.1.
     """
     lattice = fast_beam_search(
         model=model,
@@ -1537,9 +1674,11 @@ def fast_beam_search_with_nbest_rnn_rescoring(
         ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
         max_indexes = ragged_tot_scores.argmax()
         best_path = k2.index_fsa(nbest.fsa, max_indexes)
-        hyps = get_texts(best_path)

-        ans[key] = hyps
+        if not return_timestamps:
+            ans[key] = get_texts(best_path)
+        else:
+            ans[key] = get_texts_with_timestamp(best_path)

     return ans
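With these changes, a caller opts into timestamps per invocation and the old return types are kept otherwise. A usage sketch (the model and encoder tensors are assumed to exist as in the decode scripts above):

from beam_search import greedy_search_batch


def decode_with_times(model, encoder_out, encoder_out_lens):
    res = greedy_search_batch(
        model=model,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        return_timestamps=True,
    )
    # res is a DecodingResults: res.tokens[i] holds the token IDs and
    # res.timestamps[i] the subsampled frame indices of the i-th utterance.
    # Without return_timestamps=True the old list-of-list return is used.
    return res.tokens, res.timestamps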

View File

@ -106,6 +106,22 @@ Usage:
--beam 20.0 \ --beam 20.0 \
--max-contexts 8 \ --max-contexts 8 \
--max-states 64 --max-states 64
To evaluate symbol delay, you should:
(1) Generate cuts with word-time alignments:
./local/add_alignment_librispeech.py \
--alignments-dir data/alignment \
--cuts-in-dir data/fbank \
--cuts-out-dir data/fbank_ali
(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
For example:
./pruned_transducer_stateless4/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
--decoding-method greedy_search \
--manifest-dir data/fbank_ali
""" """
@ -142,10 +158,12 @@ from icefall.checkpoint import (
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
DecodingResults,
parse_hyp_and_timestamp,
setup_logger, setup_logger,
store_transcripts, store_transcripts_and_timestamps,
str2bool, str2bool,
write_error_stats, write_error_stats_with_timestamps,
) )
LOG_EPS = math.log(1e-10) LOG_EPS = math.log(1e-10)
@ -318,7 +336,7 @@ def get_parser():
"--left-context", "--left-context",
type=int, type=int,
default=64, default=64,
help="left context can be seen during decoding (in frames after subsampling)", help="left context can be seen during decoding (in frames after subsampling)", # noqa
) )
parser.add_argument( parser.add_argument(
@ -350,7 +368,7 @@ def decode_one_batch(
batch: dict, batch: dict,
word_table: Optional[k2.SymbolTable] = None, word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]: ) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
following format: following format:
@ -358,9 +376,10 @@ def decode_one_batch(
if greedy_search is used, it would be "greedy_search" if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be If beam search with a beam size of 7 is used, it would be
"beam_7" "beam_7"
- value: It contains the decoding result. `len(value)` equals to - value: It is a tuple. `len(value[0])` and `len(value[1])` are both
batch size. `value[i]` is the decoding result for the i-th equal to the batch size. `value[0][i]` and `value[1][i]`
utterance in the given batch. are the decoding result and timestamps for the i-th utterance
in the given batch respectively.
Args: Args:
params: params:
It's the return value of :func:`get_params`. It's the return value of :func:`get_params`.
@ -379,8 +398,8 @@ def decode_one_batch(
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return the decoding result. See above description for the format of Return the decoding result and timestamps. See above description for the
the returned dict. format of the returned dict.
""" """
device = next(model.parameters()).device device = next(model.parameters()).device
feature = batch["inputs"] feature = batch["inputs"]
@ -412,10 +431,8 @@ def decode_one_batch(
x=feature, x_lens=feature_lens x=feature, x_lens=feature_lens
) )
hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best( res = fast_beam_search_one_best(
model=model, model=model,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -423,11 +440,10 @@ def decode_one_batch(
beam=params.beam, beam=params.beam,
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
return_timestamps=True,
) )
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG": elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG( res = fast_beam_search_nbest_LG(
model=model, model=model,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -437,11 +453,10 @@ def decode_one_batch(
max_states=params.max_states, max_states=params.max_states,
num_paths=params.num_paths, num_paths=params.num_paths,
nbest_scale=params.nbest_scale, nbest_scale=params.nbest_scale,
return_timestamps=True,
) )
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest": elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest( res = fast_beam_search_nbest(
model=model, model=model,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -451,11 +466,10 @@ def decode_one_batch(
max_states=params.max_states, max_states=params.max_states,
num_paths=params.num_paths, num_paths=params.num_paths,
nbest_scale=params.nbest_scale, nbest_scale=params.nbest_scale,
return_timestamps=True,
) )
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle": elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle( res = fast_beam_search_nbest_oracle(
model=model, model=model,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -466,56 +480,67 @@ def decode_one_batch(
num_paths=params.num_paths, num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]), ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale, nbest_scale=params.nbest_scale,
return_timestamps=True,
) )
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif ( elif (
params.decoding_method == "greedy_search" params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1 and params.max_sym_per_frame == 1
): ):
hyp_tokens = greedy_search_batch( res = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
return_timestamps=True,
) )
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search( res = modified_beam_search(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
return_timestamps=True,
) )
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else: else:
batch_size = encoder_out.size(0) batch_size = encoder_out.size(0)
tokens = []
timestamps = []
for i in range(batch_size): for i in range(batch_size):
# fmt: off # fmt: off
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
# fmt: on # fmt: on
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
hyp = greedy_search( res = greedy_search(
model=model, model=model,
encoder_out=encoder_out_i, encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame, max_sym_per_frame=params.max_sym_per_frame,
return_timestamps=True,
) )
elif params.decoding_method == "beam_search": elif params.decoding_method == "beam_search":
hyp = beam_search( res = beam_search(
model=model, model=model,
encoder_out=encoder_out_i, encoder_out=encoder_out_i,
beam=params.beam_size, beam=params.beam_size,
return_timestamps=True,
) )
else: else:
raise ValueError( raise ValueError(
f"Unsupported decoding method: {params.decoding_method}" f"Unsupported decoding method: {params.decoding_method}"
) )
hyps.append(sp.decode(hyp).split()) tokens.extend(res.tokens)
timestamps.extend(res.timestamps)
res = DecodingResults(tokens=tokens, timestamps=timestamps)
hyps, timestamps = parse_hyp_and_timestamp(
decoding_method=params.decoding_method,
res=res,
sp=sp,
subsampling_factor=params.subsampling_factor,
frame_shift_ms=params.frame_shift_ms,
word_table=word_table,
)
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
return {"greedy_search": hyps} return {"greedy_search": (hyps, timestamps)}
elif "fast_beam_search" in params.decoding_method: elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_" key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_" key += f"max_contexts_{params.max_contexts}_"
@ -526,9 +551,9 @@ def decode_one_batch(
if "LG" in params.decoding_method: if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}" key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps} return {key: (hyps, timestamps)}
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
def decode_dataset( def decode_dataset(
@ -538,7 +563,9 @@ def decode_dataset(
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None, word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: ) -> Dict[
str, List[Tuple[str, List[str], List[str], List[float], List[float]]]
]:
"""Decode dataset. """Decode dataset.
Args: Args:
@ -559,9 +586,12 @@ def decode_dataset(
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used. is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements: Its value is a list of tuples. Each tuple contains five elements:
The first is the reference transcript, and the second is the - cut_id
predicted result. - reference transcript
- predicted result
- timestamp of reference transcript
- timestamp of predicted result
""" """
num_cuts = 0 num_cuts = 0
@ -580,6 +610,18 @@ def decode_dataset(
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
timestamps_ref = []
for cut in batch["supervisions"]["cut"]:
for s in cut.supervisions:
time = []
if s.alignment is not None and "word" in s.alignment:
time = [
aliword.start
for aliword in s.alignment["word"]
if aliword.symbol != ""
]
timestamps_ref.append(time)
hyps_dict = decode_one_batch( hyps_dict = decode_one_batch(
params=params, params=params,
model=model, model=model,
@ -589,12 +631,18 @@ def decode_dataset(
            batch=batch,
        )

        for name, (hyps, timestamps_hyp) in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
                timestamps_ref
            )
            for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
                cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
            ):
                ref_words = ref_text.split()
                this_batch.append(
                    (cut_id, ref_words, hyp_words, time_ref, time_hyp)
                )
            results[name].extend(this_batch)
@@ -612,15 +660,19 @@ def decode_dataset(
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[
        str,
        List[Tuple[str, List[str], List[str], List[float], List[float]]],
    ],
):
    test_set_wers = dict()
    test_set_delays = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        results = sorted(results)
        store_transcripts_and_timestamps(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
@@ -629,10 +681,11 @@ def save_results(
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer, mean_delay, var_delay = write_error_stats_with_timestamps(
f, f"{test_set_name}-{key}", results, enable_log=True f, f"{test_set_name}-{key}", results, enable_log=True
) )
test_set_wers[key] = wer test_set_wers[key] = wer
test_set_delays[key] = (mean_delay, var_delay)
logging.info("Wrote detailed error stats to {}".format(errs_filename)) logging.info("Wrote detailed error stats to {}".format(errs_filename))
@@ -646,6 +699,19 @@ def save_results(
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
    delays_info = (
        params.res_dir
        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
    )
    with open(delays_info, "w") as f:
        print("settings\tsymbol-delay", file=f)
        for key, val in test_set_delays:
            print(
                "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
                file=f,
            )
s = "\nFor {}, WER of different settings are:\n".format(test_set_name) s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name) note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers: for key, val in test_set_wers:
@@ -653,6 +719,15 @@ def save_results(
note = "" note = ""
logging.info(s) logging.info(s)
s = "\nFor {}, symbol-delay of different settings are:\n".format(
test_set_name
)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_delays:
s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
note = ""
logging.info(s)
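# The symbol-delay summary file written above then contains lines such as
# the following (numbers are illustrative only):
#   settings        symbol-delay
#   greedy_search   mean: 0.079s, variance: 0.002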
@torch.no_grad()
def main():


@@ -386,6 +386,7 @@ def get_params() -> AttributeDict:
""" """
params = AttributeDict( params = AttributeDict(
{ {
"frame_shift_ms": 10.0,
"best_train_loss": float("inf"), "best_train_loss": float("inf"),
"best_valid_loss": float("inf"), "best_valid_loss": float("inf"),
"best_train_epoch": -1, "best_train_epoch": -1,


@@ -24,9 +24,10 @@ import re
import subprocess
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

import k2
import k2.version
@@ -248,6 +249,86 @@ def get_texts(
    return aux_labels.tolist()
@dataclass
class DecodingResults:
    # Decoded token IDs for each utterance in the batch
    tokens: List[List[int]]

    # timestamps[i][k] contains the frame number on which tokens[i][k]
    # is decoded
    timestamps: List[List[int]]

    # hyps[i] is the recognition result, i.e., word IDs
    # for the i-th utterance with fast_beam_search_nbest_LG.
    hyps: Optional[Union[List[List[int]], k2.RaggedTensor]] = None
def get_tokens_and_timestamps(labels: List[int]) -> Tuple[List[int], List[int]]:
    """Split a frame-wise label sequence into its non-blank tokens and the
    frame indices on which they occur (label 0 is treated as blank)."""
    tokens = []
    timestamps = []
    for i, v in enumerate(labels):
        if v != 0:
            tokens.append(v)
            timestamps.append(i)

    return tokens, timestamps
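# For example (hypothetical frame-wise labels, 0 = blank):
#   tokens, times = get_tokens_and_timestamps([0, 5, 0, 0, 9, 2, 0])
#   assert tokens == [5, 9, 2] and times == [1, 4, 5]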
def get_texts_with_timestamp(
    best_paths: k2.Fsa, return_ragged: bool = False
) -> DecodingResults:
    """Extract the texts (as word IDs) and timestamps from the best-path FSAs.

    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
      return_ragged:
        True to return a ragged tensor with two axes [utt][word_id].
        False to return a list-of-list word IDs.
    Returns:
      Return a DecodingResults object whose `hyps` holds the decoded word
      IDs and whose `timestamps` holds the frame index of each label.
    """
    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
        # remove 0's and -1's.
        aux_labels = best_paths.aux_labels.remove_values_leq(0)
        # TODO: change arcs.shape() to arcs.shape
        aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)

        # remove the states and arcs axes.
        aux_shape = aux_shape.remove_axis(1)
        aux_shape = aux_shape.remove_axis(1)
        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
    else:
        # remove axis corresponding to states.
        aux_shape = best_paths.arcs.shape().remove_axis(1)
        aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
        # remove 0's and -1's.
        aux_labels = aux_labels.remove_values_leq(0)

    assert aux_labels.num_axes == 2

    labels_shape = best_paths.arcs.shape().remove_axis(1)
    labels_list = k2.RaggedTensor(
        labels_shape, best_paths.labels.contiguous()
    ).tolist()

    tokens = []
    timestamps = []
    for labels in labels_list:
        token, time = get_tokens_and_timestamps(labels[:-1])
        tokens.append(token)
        timestamps.append(time)

    return DecodingResults(
        tokens=tokens,
        timestamps=timestamps,
        hyps=aux_labels if return_ragged else aux_labels.tolist(),
    )
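# Sketch of a typical call site (assumed decoding context): after obtaining
#   best_paths = k2.shortest_path(lattice, use_double_scores=True)
# one can do
#   res = get_texts_with_timestamp(best_paths)
# where res.hyps[i] holds the word IDs and res.timestamps[i] the frame on
# which each non-blank label of the i-th utterance is emitted.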
def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
    """Extract labels or aux_labels from the best-path FSAs.
@@ -352,6 +433,33 @@ def store_transcripts(
print(f"{cut_id}:\thyp={hyp}", file=f) print(f"{cut_id}:\thyp={hyp}", file=f)
def store_transcripts_and_timestamps(
    filename: Pathlike,
    texts: Iterable[Tuple[str, List[str], List[str], List[float], List[float]]],
) -> None:
    """Save predicted results and reference transcripts as well as their
    timestamps to a file.

    Args:
      filename:
        File to save the results to.
      texts:
        An iterable of tuples. Each tuple contains the cut id, the reference
        transcript, the predicted result, and the timestamps of the
        reference and predicted words.
    Returns:
      Return None.
    """
    with open(filename, "w") as f:
        for cut_id, ref, hyp, time_ref, time_hyp in texts:
            print(f"{cut_id}:\tref={ref}", file=f)
            print(f"{cut_id}:\thyp={hyp}", file=f)
            if len(time_ref) > 0:
                s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]"
                print(f"{cut_id}:\ttimestamp_ref={s}", file=f)
            s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
            print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
def write_error_stats(
    f: TextIO,
    test_set_name: str,
@@ -519,6 +627,211 @@ def write_error_stats(
    return float(tot_err_rate)
def write_error_stats_with_timestamps(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, List[str], List[str], List[float], List[float]]],
    enable_log: bool = True,
) -> Tuple[float, float, float]:
    """Write statistics based on predicted results and reference transcripts
    as well as their timestamps.

    It will write the following to the given file:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted result.
          An instance is given below::

              THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted as `ADDISON` (a substitution error).

          Another example is::

              FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).

    Args:
      results:
        An iterable of tuples. Each tuple contains the cut id, the reference
        transcript, the predicted result, and the timestamps of the
        reference and predicted words.
      enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.
    Returns:
      Return the total word error rate, the mean symbol delay, and the
      variance of the symbol delay.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"
    # Compute mean alignment delay on the correct words
    all_delay = []
    for cut_id, ref, hyp, time_ref, time_hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        has_time_ref = len(time_ref) > 0
        if has_time_ref:
            # pointer to timestamp_hyp
            p_hyp = 0
            # pointer to timestamp_ref
            p_ref = 0
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
                if has_time_ref:
                    p_hyp += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
                if has_time_ref:
                    p_ref += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
                if has_time_ref:
                    p_hyp += 1
                    p_ref += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
                if has_time_ref:
                    all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
                    p_hyp += 1
                    p_ref += 1
        if has_time_ref:
            assert p_hyp == len(hyp), (p_hyp, len(hyp))
            assert p_ref == len(ref), (p_ref, len(ref))
    ref_len = sum([len(r) for _, r, _, _, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    mean_delay = "inf"
    var_delay = "inf"
    num_delay = len(all_delay)
    if num_delay > 0:
        mean_delay = sum(all_delay) / num_delay
        var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
        mean_delay = "%.3f" % mean_delay
        var_delay = "%.3f" % var_delay
    if enable_log:
        logging.info(
            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )
        logging.info(
            f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} "  # noqa
            f"computed on {num_delay} correct words"
        )

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
    for cut_id, ref, hyp, _, _ in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [
                [
                    list(filter(lambda a: a != ERR, x)),
                    list(filter(lambda a: a != ERR, y)),
                ]
                for x, y in ali
            ]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [
                [
                    ERR if x == [] else " ".join(x),
                    ERR if y == [] else " ".join(y),
                ]
                for x, y in ali
            ]

        print(
            f"{cut_id}:\t"
            + " ".join(
                (
                    ref_word
                    if ref_word == hyp_word
                    else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in ali
                )
            ),
            file=f,
        )
print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f)
for count, (ref, hyp) in sorted(
[(v, k) for k, v in subs.items()], reverse=True
):
print(f"{count} {ref} -> {hyp}", file=f)
print("", file=f)
print("DELETIONS: count ref", file=f)
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
print(f"{count} {ref}", file=f)
print("", file=f)
print("INSERTIONS: count hyp", file=f)
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
print(f"{count} {hyp}", file=f)
print("", file=f)
print(
"PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
)
for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
):
(corr, ref_sub, hyp_sub, ins, dels) = counts
tot_errs = ref_sub + hyp_sub + ins + dels
ref_count = corr + ref_sub + dels
hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate), float(mean_delay), float(var_delay)
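# A worked example of the delay statistics (hypothetical delays): delays are
# collected only on correctly recognized words, so if those delays are
# [0.04, 0.06] seconds, the function reports mean (0.04 + 0.06) / 2 = 0.050s
# and population variance ((0.04 - 0.05)**2 + (0.06 - 0.05)**2) / 2, which
# prints as 0.000 at three decimal places.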
class MetricsTracker(collections.defaultdict):
    def __init__(self):
        # Passing the type 'int' to the base-class constructor
@@ -978,6 +1291,137 @@ def display_and_save_batch(
logging.info(f"num tokens: {num_tokens}") logging.info(f"num tokens: {num_tokens}")
def convert_timestamp(
    frames: List[int],
    subsampling_factor: int,
    frame_shift_ms: float = 10,
) -> List[float]:
    """Convert frame numbers to time (in seconds) given the subsampling factor
    and the frame shift (in milliseconds).

    Args:
      frames:
        A list of frame numbers after subsampling.
      subsampling_factor:
        The subsampling factor of the model.
      frame_shift_ms:
        Frame shift in milliseconds between two contiguous frames.
    Return:
      Return the time in seconds corresponding to each given frame.
    """
    frame_shift = frame_shift_ms / 1000.0
    time = []
    for f in frames:
        time.append(f * subsampling_factor * frame_shift)

    return time
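# A worked example (values assumed for illustration): with this recipe's
# defaults, subsampling_factor=4 and frame_shift_ms=10, output frame 25
# corresponds to 25 * 4 * 0.01 = 1.0 second:
#   assert convert_timestamp([0, 25], 4, 10.0) == [0.0, 1.0]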
def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
    """
    Parse timestamp of each word.

    Args:
      tokens:
        List of tokens.
      timestamp:
        List of timestamp of each token.

    Returns:
      List of timestamp of each word.
    """
    # '▁' (U+2581), the sentencepiece word-boundary marker
    start_token = b"\xe2\x96\x81".decode()
    assert len(tokens) == len(timestamp)
    ans = []
    for i in range(len(tokens)):
        flag = False
        if i == 0 or tokens[i].startswith(start_token):
            flag = True
            if len(tokens[i]) == 1 and tokens[i].startswith(start_token):
                # tokens[i] == start_token
                if i == len(tokens) - 1:
                    # it is the last token
                    flag = False
                elif tokens[i + 1].startswith(start_token):
                    # the next token also starts with start_token
                    flag = False
        if flag:
            ans.append(timestamp[i])

    return ans
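# A minimal sketch with hypothetical BPE pieces: '▁' marks a word start, so
# only the timestamps of word-initial tokens are kept:
#   assert parse_timestamp(["▁HE", "LLO", "▁WORLD"], [0.0, 0.2, 0.5]) == [0.0, 0.5]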
def parse_hyp_and_timestamp(
    res: DecodingResults,
    decoding_method: str,
    sp: spm.SentencePieceProcessor,
    subsampling_factor: int,
    frame_shift_ms: float = 10,
    word_table: Optional[k2.SymbolTable] = None,
) -> Tuple[List[List[str]], List[List[float]]]:
    """Parse hypothesis and timestamp.

    Args:
      res:
        A DecodingResults object.
      decoding_method:
        Possible values are:
          - greedy_search
          - beam_search
          - modified_beam_search
          - fast_beam_search
          - fast_beam_search_nbest
          - fast_beam_search_nbest_oracle
          - fast_beam_search_nbest_LG
      sp:
        The BPE model.
      subsampling_factor:
        The integer subsampling factor.
      frame_shift_ms:
        The float frame shift used for feature extraction.
      word_table:
        The word symbol table.

    Returns:
      Return a list of hypotheses and a list of their word-level timestamps.
    """
    assert decoding_method in (
        "greedy_search",
        "beam_search",
        "fast_beam_search",
        "fast_beam_search_nbest",
        "fast_beam_search_nbest_LG",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
    )
    hyps = []
    timestamps = []

    N = len(res.tokens)
    assert len(res.timestamps) == N
    use_word_table = False
    if decoding_method == "fast_beam_search_nbest_LG":
        assert word_table is not None
        use_word_table = True

    for i in range(N):
        tokens = sp.id_to_piece(res.tokens[i])
        if use_word_table:
            words = [word_table[w] for w in res.hyps[i]]
        else:
            words = sp.decode_pieces(tokens).split()
        time = convert_timestamp(
            res.timestamps[i], subsampling_factor, frame_shift_ms
        )
        time = parse_timestamp(tokens, time)
        assert len(time) == len(words), (tokens, words)

        hyps.append(words)
        timestamps.append(time)

    return hyps, timestamps
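# Sketch of the intended call site (mirrors decode_one_batch above; the
# subsampling factor is assumed for illustration):
#   hyps, timestamps = parse_hyp_and_timestamp(
#       res=res,
#       decoding_method="greedy_search",
#       sp=sp,
#       subsampling_factor=4,
#       frame_shift_ms=10.0,
#   )
# hyps[i] is a list of words and timestamps[i] the start time in seconds of
# each word for the i-th utterance.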
# `is_module_available` is copied from
# https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9
def is_module_available(*modules: str) -> bool: