Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
Get timestamps during decoding (#598)
* print out timestamps during decoding
* add word-level alignments
* support computing mean symbol delay with word-level alignments
* print variance of symbol delay
* update doc
* support computing delay for pruned_transducer_stateless4
* fix bug
* add doc
This commit is contained in:
parent ff3f026381
commit 03668771d7
egs/librispeech/ASR/add_alignments.sh (12 lines, executable file)
@@ -0,0 +1,12 @@
#!/usr/bin/env bash

set -eou pipefail

alignments_dir=data/alignment
cuts_in_dir=data/fbank
cuts_out_dir=data/fbank_ali

python3 ./local/add_alignment_librispeech.py \
  --alignments-dir $alignments_dir \
  --cuts-in-dir $cuts_in_dir \
  --cuts-out-dir $cuts_out_dir
egs/librispeech/ASR/local/add_alignment_librispeech.py (196 lines, executable file)
@@ -0,0 +1,196 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file adds alignments from https://github.com/CorentinJ/librispeech-alignments  # noqa
to the existing fbank features dir (e.g., data/fbank)
and saves cuts to a new dir (e.g., data/fbank_ali).
"""

import argparse
import logging
import zipfile
from pathlib import Path
from typing import List

from lhotse import CutSet, load_manifest_lazy
from lhotse.recipes.librispeech import parse_alignments
from lhotse.utils import is_module_available

LIBRISPEECH_ALIGNMENTS_URL = (
    "https://drive.google.com/uc?id=1WYfgr31T-PPwMcxuAq09XZfHQO5Mw8fE"
)

DATASET_PARTS = [
    "dev-clean",
    "dev-other",
    "test-clean",
    "test-other",
    "train-clean-100",
    "train-clean-360",
    "train-other-500",
]


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--alignments-dir",
        type=str,
        default="data/alignment",
        help="The dir to save alignments.",
    )

    parser.add_argument(
        "--cuts-in-dir",
        type=str,
        default="data/fbank",
        help="The dir of the existing cuts without alignments.",
    )

    parser.add_argument(
        "--cuts-out-dir",
        type=str,
        default="data/fbank_ali",
        help="The dir to save the new cuts with alignments.",
    )

    return parser


def download_alignments(
    target_dir: str, alignments_url: str = LIBRISPEECH_ALIGNMENTS_URL
):
    """
    Download and extract the alignments.
    Modified from
    https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/librispeech.py  # noqa

    Note: If you cannot access drive.google.com, you could download the file
    `LibriSpeech-Alignments.zip` from huggingface:
    https://huggingface.co/Zengwei/librispeech-alignments
    and extract the zip file manually.

    Args:
      target_dir:
        The dir to save alignments.
      alignments_url:
        The URL of alignments.
    """
    target_dir = Path(target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    completed_detector = target_dir / ".ali_completed"
    if completed_detector.is_file():
        logging.info("The alignment files already exist.")
        return

    ali_zip_path = target_dir / "LibriSpeech-Alignments.zip"
    if not ali_zip_path.is_file():
        assert is_module_available(
            "gdown"
        ), 'To download LibriSpeech alignments, please install "pip install gdown"'  # noqa
        import gdown

        gdown.download(alignments_url, output=str(ali_zip_path))

    with zipfile.ZipFile(str(ali_zip_path)) as f:
        f.extractall(path=target_dir)
        completed_detector.touch()


def add_alignment(
    alignments_dir: str,
    cuts_in_dir: str = "data/fbank",
    cuts_out_dir: str = "data/fbank_ali",
    dataset_parts: List[str] = DATASET_PARTS,
):
    """
    Add alignment info to existing cuts.

    Args:
      alignments_dir:
        The dir of the alignments.
      cuts_in_dir:
        The dir of the existing cuts.
      cuts_out_dir:
        The dir to save the new cuts with alignments.
      dataset_parts:
        LibriSpeech parts to add alignments to.
    """
    alignments_dir = Path(alignments_dir)
    cuts_in_dir = Path(cuts_in_dir)
    cuts_out_dir = Path(cuts_out_dir)
    cuts_out_dir.mkdir(parents=True, exist_ok=True)

    for part in dataset_parts:
        logging.info(f"Processing {part}")

        cuts_in_path = cuts_in_dir / f"librispeech_cuts_{part}.jsonl.gz"
        if not cuts_in_path.is_file():
            logging.info(f"{cuts_in_path} does not exist - skipping.")
            continue
        cuts_out_path = cuts_out_dir / f"librispeech_cuts_{part}.jsonl.gz"
        if cuts_out_path.is_file():
            logging.info(f"{part} already exists - skipping.")
            continue

        # parse alignments
        alignments = {}
        part_ali_dir = alignments_dir / "LibriSpeech" / part
        for ali_path in part_ali_dir.rglob("*.alignment.txt"):
            ali = parse_alignments(ali_path)
            alignments.update(ali)
        logging.info(
            f"{part} has {len(alignments.keys())} cuts with alignments."
        )

        # add alignment attribute and write out
        cuts_in = load_manifest_lazy(cuts_in_path)
        with CutSet.open_writer(cuts_out_path) as writer:
            for cut in cuts_in:
                for idx, subcut in enumerate(cut.supervisions):
                    origin_id = subcut.id.split("_")[0]
                    if origin_id in alignments:
                        ali = alignments[origin_id]
                    else:
                        logging.info(
                            f"Warning: {origin_id} does not have an alignment."
                        )
                        ali = []
                    subcut.alignment = {"word": ali}
                writer.write(cut, flush=True)


def main():
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)

    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    download_alignments(args.alignments_dir)
    add_alignment(args.alignments_dir, args.cuts_in_dir, args.cuts_out_dir)


if __name__ == "__main__":
    main()
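For reference, a minimal sketch of inspecting the word alignments this script attaches. It assumes lhotse's AlignmentItem API (symbol, start, duration fields) and the default output dir; it is not part of the commit:

from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank_ali/librispeech_cuts_test-clean.jsonl.gz")
for cut in cuts:
    for sup in cut.supervisions:
        if sup.alignment is not None and "word" in sup.alignment:
            for item in sup.alignment["word"]:
                # an empty symbol marks a silence segment in these alignments
                if item.symbol != "":
                    print(item.symbol, item.start, item.duration)
    break  # only inspect the first cut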
@@ -91,6 +91,22 @@ Usage:
     --beam 20.0 \
     --max-contexts 8 \
     --max-states 64
+
+To evaluate symbol delay, you should:
+(1) Generate cuts with word-time alignments:
+./local/add_alignment_librispeech.py \
+    --alignments-dir data/alignment \
+    --cuts-in-dir data/fbank \
+    --cuts-out-dir data/fbank_ali
+(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
+For example:
+./lstm_transducer_stateless3/decode.py \
+    --epoch 40 \
+    --avg 20 \
+    --exp-dir ./lstm_transducer_stateless3/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search \
+    --manifest-dir data/fbank_ali
 """
 
 
@@ -127,10 +143,12 @@ from icefall.checkpoint import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    DecodingResults,
+    parse_hyp_and_timestamp,
     setup_logger,
-    store_transcripts,
+    store_transcripts_and_timestamps,
     str2bool,
-    write_error_stats,
+    write_error_stats_with_timestamps,
 )
 
 LOG_EPS = math.log(1e-10)
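The hunk above pulls in DecodingResults and parse_hyp_and_timestamp from icefall.utils (both introduced by this commit). A minimal sketch of the structure the decoding code below appears to rely on; the field names are inferred from usages such as res.tokens and res.timestamps elsewhere in the diff, so treat this as an assumption rather than the actual definition:

from dataclasses import dataclass
from typing import List

@dataclass
class DecodingResults:
    # tokens[i] is the token-ID sequence for the i-th utterance
    tokens: List[List[int]]
    # timestamps[i][k] is the frame index (after subsampling) on which
    # tokens[i][k] was emitted
    timestamps: List[List[int]]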
@@ -314,7 +332,7 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[str]]]:
+) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
 
@@ -322,9 +340,11 @@ def decode_one_batch(
         if greedy_search is used, it would be "greedy_search"
         If beam search with a beam size of 7 is used, it would be
         "beam_7"
-    - value: It contains the decoding result. `len(value)` equals to
-      batch size. `value[i]` is the decoding result for the i-th
-      utterance in the given batch.
+    - value: It is a tuple. `len(value[0])` and `len(value[1])` are both
+      equal to the batch size. `value[0][i]` and `value[1][i]`
+      are the decoding result and timestamps for the i-th utterance
+      in the given batch respectively.
 
     Args:
       params:
        It's the return value of :func:`get_params`.
@@ -343,8 +363,8 @@ def decode_one_batch(
         only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
-      Return the decoding result. See above description for the format of
-      the returned dict.
+      Return the decoding result and timestamps. See above description for the
+      format of the returned dict.
     """
     device = next(model.parameters()).device
     feature = batch["inputs"]
@@ -370,10 +390,8 @@ def decode_one_batch(
         x=feature, x_lens=feature_lens
     )
 
-    hyps = []
-
     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search_one_best(
+        res = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -381,11 +399,10 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_nbest_LG":
-        hyp_tokens = fast_beam_search_nbest_LG(
+        res = fast_beam_search_nbest_LG(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -395,11 +412,10 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            return_timestamps=True,
         )
-        for hyp in hyp_tokens:
-            hyps.append([word_table[i] for i in hyp])
     elif params.decoding_method == "fast_beam_search_nbest":
-        hyp_tokens = fast_beam_search_nbest(
+        res = fast_beam_search_nbest(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -409,11 +425,10 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_nbest_oracle":
-        hyp_tokens = fast_beam_search_nbest_oracle(
+        res = fast_beam_search_nbest_oracle(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -424,56 +439,67 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=sp.encode(supervisions["text"]),
             nbest_scale=params.nbest_scale,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
     ):
-        hyp_tokens = greedy_search_batch(
+        res = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            return_timestamps=True,
        )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search":
-        hyp_tokens = modified_beam_search(
+        res = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
+        tokens = []
+        timestamps = []
         for i in range(batch_size):
             # fmt: off
             encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
             # fmt: on
             if params.decoding_method == "greedy_search":
-                hyp = greedy_search(
+                res = greedy_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     max_sym_per_frame=params.max_sym_per_frame,
+                    return_timestamps=True,
                 )
             elif params.decoding_method == "beam_search":
-                hyp = beam_search(
+                res = beam_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
+                    return_timestamps=True,
                 )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            tokens.extend(res.tokens)
+            timestamps.extend(res.timestamps)
+        res = DecodingResults(tokens=tokens, timestamps=timestamps)
+
+    hyps, timestamps = parse_hyp_and_timestamp(
+        decoding_method=params.decoding_method,
+        res=res,
+        sp=sp,
+        subsampling_factor=params.subsampling_factor,
+        frame_shift_ms=params.frame_shift_ms,
+        word_table=word_table,
+    )
+
     if params.decoding_method == "greedy_search":
-        return {"greedy_search": hyps}
+        return {"greedy_search": (hyps, timestamps)}
     elif "fast_beam_search" in params.decoding_method:
         key = f"beam_{params.beam}_"
         key += f"max_contexts_{params.max_contexts}_"
@@ -484,9 +510,9 @@ def decode_one_batch(
         if "LG" in params.decoding_method:
             key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
 
-        return {key: hyps}
+        return {key: (hyps, timestamps)}
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
 
 
 def decode_dataset(
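parse_hyp_and_timestamp receives subsampling_factor and frame_shift_ms, which suggests the frame-index-to-seconds conversion sketched below. This is only the implied arithmetic, not the actual helper from icefall.utils:

def frames_to_seconds(frames, subsampling_factor=4, frame_shift_ms=10.0):
    # each post-subsampling frame covers subsampling_factor * frame_shift_ms ms
    return [f * subsampling_factor * frame_shift_ms / 1000.0 for f in frames]

print(frames_to_seconds([0, 12, 37]))  # [0.0, 0.48, 1.48]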
@@ -496,7 +522,9 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+) -> Dict[
+    str, List[Tuple[str, List[str], List[str], List[float], List[float]]]
+]:
     """Decode dataset.
 
     Args:
@@ -517,9 +545,12 @@ def decode_dataset(
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
-      Its value is a list of tuples. Each tuple contains two elements:
-      The first is the reference transcript, and the second is the
-      predicted result.
+      Its value is a list of tuples. Each tuple contains five elements:
+      - cut_id
+      - reference transcript
+      - predicted result
+      - timestamp of reference transcript
+      - timestamp of predicted result
     """
     num_cuts = 0
 
@@ -538,6 +569,18 @@ def decode_dataset(
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
+        timestamps_ref = []
+        for cut in batch["supervisions"]["cut"]:
+            for s in cut.supervisions:
+                time = []
+                if s.alignment is not None and "word" in s.alignment:
+                    time = [
+                        aliword.start
+                        for aliword in s.alignment["word"]
+                        if aliword.symbol != ""
+                    ]
+                timestamps_ref.append(time)
+
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
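A runnable toy version of the reference-timestamp extraction added above, with a stand-in namedtuple in place of lhotse's AlignmentItem (assumed fields: symbol, start):

from collections import namedtuple

AlignmentItem = namedtuple("AlignmentItem", ["symbol", "start"])
word_ali = [
    AlignmentItem("", 0.00),      # "" marks a silence segment
    AlignmentItem("hello", 0.32),
    AlignmentItem("world", 0.81),
]
time = [w.start for w in word_ali if w.symbol != ""]
print(time)  # [0.32, 0.81]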
@@ -547,12 +590,18 @@ def decode_dataset(
             batch=batch,
         )
 
-        for name, hyps in hyps_dict.items():
+        for name, (hyps, timestamps_hyp) in hyps_dict.items():
             this_batch = []
-            assert len(hyps) == len(texts)
-            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+            assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
+                timestamps_ref
+            )
+            for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
+                cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
+            ):
                 ref_words = ref_text.split()
-                this_batch.append((cut_id, ref_words, hyp_words))
+                this_batch.append(
+                    (cut_id, ref_words, hyp_words, time_ref, time_hyp)
+                )
 
             results[name].extend(this_batch)
 
@@ -570,15 +619,19 @@
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    results_dict: Dict[
+        str,
+        List[Tuple[str, List[str], List[str], List[float], List[float]]],
+    ],
 ):
     test_set_wers = dict()
+    test_set_delays = dict()
     for key, results in results_dict.items():
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        store_transcripts_and_timestamps(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
@@ -587,10 +640,11 @@ def save_results(
             params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
-            wer = write_error_stats(
+            wer, mean_delay, var_delay = write_error_stats_with_timestamps(
                 f, f"{test_set_name}-{key}", results, enable_log=True
             )
             test_set_wers[key] = wer
+            test_set_delays[key] = (mean_delay, var_delay)
 
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
@@ -604,6 +658,19 @@ def save_results(
         for key, val in test_set_wers:
             print("{}\t{}".format(key, val), file=f)
 
+    test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
+    delays_info = (
+        params.res_dir
+        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(delays_info, "w") as f:
+        print("settings\tsymbol-delay", file=f)
+        for key, val in test_set_delays:
+            print(
+                "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
+                file=f,
+            )
+
     s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key, val in test_set_wers:
@@ -611,6 +678,15 @@ def save_results(
             note = ""
     logging.info(s)
 
+    s = "\nFor {}, symbol-delay of different settings are:\n".format(
+        test_set_name
+    )
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_delays:
+        s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
+        note = ""
+    logging.info(s)
+
 
 @torch.no_grad()
 def main():
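write_error_stats_with_timestamps returns a mean and a variance of symbol delay. A sketch of the statistic itself, computed over word pairs aligned between reference and hypothesis; the real function (in icefall.utils) also has to handle insertions and deletions:

def delay_stats(ref_times, hyp_times):
    # delay of each emitted word relative to its reference start time
    delays = [h - r for r, h in zip(ref_times, hyp_times)]
    mean = sum(delays) / len(delays)
    var = sum((d - mean) ** 2 for d in delays) / len(delays)
    return mean, var

print(delay_stats([0.48, 1.04, 1.60], [0.52, 1.12, 1.64]))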
@@ -377,6 +377,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
+            "frame_shift_ms": 10.0,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -16,7 +16,7 @@
 
 import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 
 import k2
 import sentencepiece as spm
@@ -25,7 +25,13 @@ from model import Transducer
 
 from icefall import NgramLm, NgramLmStateCost
 from icefall.decode import Nbest, one_best_decoding
-from icefall.utils import add_eos, add_sos, get_texts
+from icefall.utils import (
+    DecodingResults,
+    add_eos,
+    add_sos,
+    get_texts,
+    get_texts_with_timestamp,
+)
 
 
 def fast_beam_search_one_best(
@@ -37,7 +43,8 @@ def fast_beam_search_one_best(
     max_states: int,
     max_contexts: int,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.
 
     A lattice is first obtained using fast beam search, and then
@@ -61,8 +68,12 @@ def fast_beam_search_one_best(
       Max contexts per stream per frame.
     temperature:
       Softmax temperature.
+    return_timestamps:
+      Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,
@@ -76,8 +87,11 @@ def fast_beam_search_one_best(
     )
 
     best_path = one_best_decoding(lattice)
-    hyps = get_texts(best_path)
 
-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)
 
 
 def fast_beam_search_nbest_LG(
@@ -92,7 +106,8 @@ def fast_beam_search_nbest_LG(
     nbest_scale: float = 0.5,
     use_double_scores: bool = True,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.
 
     The process to get the results is:
@@ -129,8 +144,12 @@ def fast_beam_search_nbest_LG(
       single precision.
     temperature:
      Softmax temperature.
+    return_timestamps:
+      Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,
@@ -195,9 +214,10 @@ def fast_beam_search_nbest_LG(
     best_hyp_indexes = ragged_tot_scores.argmax()
     best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
 
-    hyps = get_texts(best_path)
-
-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)
 
 
 def fast_beam_search_nbest(
@@ -212,7 +232,8 @@ def fast_beam_search_nbest(
     nbest_scale: float = 0.5,
     use_double_scores: bool = True,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.
 
     The process to get the results is:
@@ -249,8 +270,12 @@ def fast_beam_search_nbest(
       single precision.
     temperature:
       Softmax temperature.
+    return_timestamps:
+      Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,
@@ -279,9 +304,10 @@ def fast_beam_search_nbest(
 
     best_path = k2.index_fsa(nbest.fsa, max_indexes)
 
-    hyps = get_texts(best_path)
-
-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)
 
 
 def fast_beam_search_nbest_oracle(
@@ -297,7 +323,8 @@ def fast_beam_search_nbest_oracle(
     use_double_scores: bool = True,
     nbest_scale: float = 0.5,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.
 
     A lattice is first obtained using fast beam search, and then
@@ -338,8 +365,12 @@ def fast_beam_search_nbest_oracle(
       yields more unique paths.
     temperature:
       Softmax temperature.
+    return_timestamps:
+      Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,
@@ -378,8 +409,10 @@ def fast_beam_search_nbest_oracle(
 
     best_path = k2.index_fsa(nbest.fsa, max_indexes)
 
-    hyps = get_texts(best_path)
-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)
 
 
 def fast_beam_search(
@@ -469,8 +502,11 @@ def fast_beam_search(
 
 
 def greedy_search(
-    model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
-) -> List[int]:
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    max_sym_per_frame: int,
+    return_timestamps: bool = False,
+) -> Union[List[int], DecodingResults]:
     """Greedy search for a single utterance.
     Args:
       model:
@@ -480,8 +516,12 @@ def greedy_search(
     max_sym_per_frame:
       Maximum number of symbols per frame. If it is set to 0, the WER
      would be 100%.
+    return_timestamps:
+      Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     assert encoder_out.ndim == 3
 
@@ -507,6 +547,10 @@ def greedy_search(
     t = 0
     hyp = [blank_id] * context_size
 
+    # timestamp[i] is the frame index after subsampling
+    # on which hyp[i] is decoded
+    timestamp = []
+
     # Maximum symbols per utterance.
     max_sym_per_utt = 1000
 
@@ -533,6 +577,7 @@ def greedy_search(
         y = logits.argmax().item()
         if y not in (blank_id, unk_id):
             hyp.append(y)
+            timestamp.append(t)
             decoder_input = torch.tensor(
                 [hyp[-context_size:]], device=device
             ).reshape(1, context_size)
@@ -547,14 +592,21 @@ def greedy_search(
         t += 1
     hyp = hyp[context_size:]  # remove blanks
 
-    return hyp
+    if not return_timestamps:
+        return hyp
+    else:
+        return DecodingResults(
+            tokens=[hyp],
+            timestamps=[timestamp],
+        )
 
 
 def greedy_search_batch(
     model: Transducer,
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
     Args:
       model:
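A model-free toy run of the bookkeeping added above: whenever a non-blank token is emitted at frame t, t is appended alongside it, so tokens and timestamps stay index-aligned (max-sym-per-frame=1 style; the per-frame argmax IDs below are made up):

blank_id = 0
frame_argmax = [0, 5, 0, 0, 9, 9, 0, 2]  # assumed per-frame argmax token IDs
hyp, timestamp = [], []
for t, y in enumerate(frame_argmax):
    if y != blank_id:
        hyp.append(y)
        timestamp.append(t)
print(hyp)        # [5, 9, 9, 2]
print(timestamp)  # [1, 4, 5, 7]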
@@ -564,9 +616,12 @@ def greedy_search_batch(
     encoder_out_lens:
       A 1-D tensor of shape (N,), containing number of valid frames in
       encoder_out before padding.
+    return_timestamps:
+      Whether to return timestamps.
     Returns:
-      Return a list-of-list of token IDs containing the decoded results.
-      len(ans) equals to encoder_out.size(0).
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     assert encoder_out.ndim == 3
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
@@ -591,6 +646,10 @@ def greedy_search_batch(
 
     hyps = [[blank_id] * context_size for _ in range(N)]
 
+    # timestamp[n][i] is the frame index after subsampling
+    # on which hyp[n][i] is decoded
+    timestamps = [[] for _ in range(N)]
+
     decoder_input = torch.tensor(
         hyps,
         device=device,
@@ -604,7 +663,7 @@ def greedy_search_batch(
     encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
 
     offset = 0
-    for batch_size in batch_size_list:
+    for (t, batch_size) in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]
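The switch to enumerate(batch_size_list) makes t the shared frame index of every utterance still active at that step of the packed batch. A small illustration with torch's pack_padded_sequence (toy shapes assumed, not from the diff):

import torch
from torch.nn.utils.rnn import pack_padded_sequence

x = torch.zeros(2, 4, 8)  # (N, T, C), lengths sorted descending
packed = pack_padded_sequence(
    x, lengths=torch.tensor([4, 2]), batch_first=True, enforce_sorted=True
)
print(packed.batch_sizes.tolist())  # [2, 2, 1, 1]
for t, batch_size in enumerate(packed.batch_sizes.tolist()):
    # t indexes frames; batch_size counts the streams still active at frame t
    print(t, batch_size)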
@@ -626,6 +685,7 @@ def greedy_search_batch(
         for i, v in enumerate(y):
             if v not in (blank_id, unk_id):
                 hyps[i].append(v)
+                timestamps[i].append(t)
                 emitted = True
         if emitted:
             # update decoder output
@@ -640,11 +700,19 @@ def greedy_search_batch(
 
     sorted_ans = [h[context_size:] for h in hyps]
     ans = []
+    ans_timestamps = []
     unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])
+        ans_timestamps.append(timestamps[unsorted_indices[i]])
 
-    return ans
+    if not return_timestamps:
+        return ans
+    else:
+        return DecodingResults(
+            tokens=ans,
+            timestamps=ans_timestamps,
+        )
 
 
 @dataclass
@@ -657,6 +725,10 @@ class Hypothesis:
     # It contains only one entry.
     log_prob: torch.Tensor
 
+    # timestamp[i] is the frame index after subsampling
+    # on which ys[i] is decoded
+    timestamp: List[int]
+
     state_cost: Optional[NgramLmStateCost] = None
 
     @property
@@ -806,7 +878,8 @@ def modified_beam_search(
     encoder_out_lens: torch.Tensor,
     beam: int = 4,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
 
     Args:
@@ -821,9 +894,12 @@ def modified_beam_search(
       Number of active paths during the beam search.
     temperature:
       Softmax temperature.
+    return_timestamps:
+      Whether to return timestamps.
     Returns:
-      Return a list-of-list of token IDs. ans[i] is the decoding results
-      for the i-th utterance.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     assert encoder_out.ndim == 3, encoder_out.shape
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
@@ -851,6 +927,7 @@ def modified_beam_search(
             Hypothesis(
                 ys=[blank_id] * context_size,
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                timestamp=[],
             )
         )
 
@@ -858,7 +935,7 @@ def modified_beam_search(
 
     offset = 0
     finalized_B = []
-    for batch_size in batch_size_list:
+    for (t, batch_size) in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]
@@ -936,30 +1013,44 @@ def modified_beam_search(
 
                 new_ys = hyp.ys[:]
                 new_token = topk_token_indexes[k]
+                new_timestamp = hyp.timestamp[:]
                 if new_token not in (blank_id, unk_id):
                     new_ys.append(new_token)
+                    new_timestamp.append(t)
 
                 new_log_prob = topk_log_probs[k]
-                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+                new_hyp = Hypothesis(
+                    ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
+                )
                 B[i].add(new_hyp)
 
     B = B + finalized_B
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]
 
     sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    sorted_timestamps = [h.timestamp for h in best_hyps]
     ans = []
+    ans_timestamps = []
     unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])
+        ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
 
-    return ans
+    if not return_timestamps:
+        return ans
+    else:
+        return DecodingResults(
+            tokens=ans,
+            timestamps=ans_timestamps,
+        )
 
 
 def _deprecated_modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
-) -> List[int]:
+    return_timestamps: bool = False,
+) -> Union[List[int], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.
 
     It decodes only one utterance at a time. We keep it only for reference.
@@ -974,8 +1065,13 @@ def _deprecated_modified_beam_search(
       A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
     beam:
       Beam size.
+    return_timestamps:
+      Whether to return timestamps.
+
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
 
     assert encoder_out.ndim == 3
@@ -995,6 +1091,7 @@ def _deprecated_modified_beam_search(
         Hypothesis(
             ys=[blank_id] * context_size,
             log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+            timestamp=[],
         )
     )
     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -1053,17 +1150,24 @@ def _deprecated_modified_beam_search(
        for i in range(len(topk_hyp_indexes)):
            hyp = A[topk_hyp_indexes[i]]
            new_ys = hyp.ys[:]
+            new_timestamp = hyp.timestamp[:]
            new_token = topk_token_indexes[i]
            if new_token not in (blank_id, unk_id):
                new_ys.append(new_token)
+                new_timestamp.append(t)
            new_log_prob = topk_log_probs[i]
-            new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+            new_hyp = Hypothesis(
+                ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
+            )
            B.add(new_hyp)
 
    best_hyp = B.get_most_probable(length_norm=True)
    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
 
-    return ys
+    if not return_timestamps:
+        return ys
+    else:
+        return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
 
 
 def beam_search(
@@ -1071,7 +1175,8 @@ def beam_search(
     encoder_out: torch.Tensor,
     beam: int = 4,
     temperature: float = 1.0,
-) -> List[int]:
+    return_timestamps: bool = False,
+) -> Union[List[int], DecodingResults]:
     """
     It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
 
@@ -1086,8 +1191,13 @@ def beam_search(
       Beam size.
     temperature:
       Softmax temperature.
+    return_timestamps:
+      Whether to return timestamps.
+
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     assert encoder_out.ndim == 3
 
@@ -1114,7 +1224,7 @@ def beam_search(
     t = 0
 
     B = HypothesisList()
-    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
+    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))
 
     max_sym_per_utt = 20000
 
@@ -1175,7 +1285,13 @@ def beam_search(
             new_y_star_log_prob = y_star.log_prob + skip_log_prob
 
             # ys[:] returns a copy of ys
-            B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
+            B.add(
+                Hypothesis(
+                    ys=y_star.ys[:],
+                    log_prob=new_y_star_log_prob,
+                    timestamp=y_star.timestamp[:],
+                )
+            )
 
             # Second, process other non-blank labels
             values, indices = log_prob.topk(beam + 1)
@@ -1184,7 +1300,14 @@ def beam_search(
                 continue
             new_ys = y_star.ys + [i]
             new_log_prob = y_star.log_prob + v
-            A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
+            new_timestamp = y_star.timestamp + [t]
+            A.add(
+                Hypothesis(
+                    ys=new_ys,
+                    log_prob=new_log_prob,
+                    timestamp=new_timestamp,
+                )
+            )
 
         # Check whether B contains more than "beam" elements more probable
         # than the most probable in A
@@ -1200,7 +1323,11 @@ def beam_search(
 
     best_hyp = B.get_most_probable(length_norm=True)
     ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
-    return ys
+
+    if not return_timestamps:
+        return ys
+    else:
+        return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
 
 
 def fast_beam_search_with_nbest_rescoring(
@@ -1220,7 +1347,8 @@ def fast_beam_search_with_nbest_rescoring(
     use_double_scores: bool = True,
     nbest_scale: float = 0.5,
     temperature: float = 1.0,
-) -> Dict[str, List[List[int]]]:
+    return_timestamps: bool = False,
+) -> Dict[str, Union[List[List[int]], DecodingResults]]:
     """It limits the maximum number of symbols per frame to 1.
     A lattice is first obtained using fast beam search, num_path paths are selected
     and rescored using a given language model. The shortest path within the
@@ -1262,10 +1390,13 @@ def fast_beam_search_with_nbest_rescoring(
       yields more unique paths.
     temperature:
       Softmax temperature.
+    return_timestamps:
+      Whether to return timestamps.
     Returns:
       Return the decoded result in a dict, where the key has the form
-      'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
-      ngram LM scale value used during decoding, i.e., 0.1.
+      'ngram_lm_scale_xx' and the value is the decoded results
+      optionally with timestamps. `xx` is the ngram LM scale value
+      used during decoding, i.e., 0.1.
     """
     lattice = fast_beam_search(
         model=model,
@@ -1343,16 +1474,18 @@ def fast_beam_search_with_nbest_rescoring(
         log_semiring=False,
     )
 
-    ans: Dict[str, List[List[int]]] = {}
+    ans: Dict[str, Union[List[List[int]], DecodingResults]] = {}
     for s in ngram_lm_scale_list:
         key = f"ngram_lm_scale_{s}"
         tot_scores = am_scores.values + s * ngram_lm_scores
         ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
         max_indexes = ragged_tot_scores.argmax()
         best_path = k2.index_fsa(nbest.fsa, max_indexes)
-        hyps = get_texts(best_path)
 
-        ans[key] = hyps
+        if not return_timestamps:
+            ans[key] = get_texts(best_path)
+        else:
+            ans[key] = get_texts_with_timestamp(best_path)
 
     return ans
 
@@ -1376,7 +1509,8 @@ def fast_beam_search_with_nbest_rnn_rescoring(
     use_double_scores: bool = True,
     nbest_scale: float = 0.5,
     temperature: float = 1.0,
-) -> Dict[str, List[List[int]]]:
+    return_timestamps: bool = False,
+) -> Dict[str, Union[List[List[int]], DecodingResults]]:
     """It limits the maximum number of symbols per frame to 1.
     A lattice is first obtained using fast beam search, num_path paths are selected
     and rescored using a given language model and a rnn-lm.
@@ -1422,10 +1556,13 @@ def fast_beam_search_with_nbest_rnn_rescoring(
      yields more unique paths.
     temperature:
       Softmax temperature.
+    return_timestamps:
+      Whether to return timestamps.
     Returns:
       Return the decoded result in a dict, where the key has the form
-      'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
-      ngram LM scale value used during decoding, i.e., 0.1.
+      'ngram_lm_scale_xx' and the value is the decoded results
+      optionally with timestamps. `xx` is the ngram LM scale value
+      used during decoding, i.e., 0.1.
     """
     lattice = fast_beam_search(
         model=model,
@@ -1537,9 +1674,11 @@ def fast_beam_search_with_nbest_rnn_rescoring(
         ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
         max_indexes = ragged_tot_scores.argmax()
         best_path = k2.index_fsa(nbest.fsa, max_indexes)
-        hyps = get_texts(best_path)
 
-        ans[key] = hyps
+        if not return_timestamps:
+            ans[key] = get_texts(best_path)
+        else:
+            ans[key] = get_texts_with_timestamp(best_path)
 
     return ans
 
@@ -106,6 +106,22 @@ Usage:
     --beam 20.0 \
     --max-contexts 8 \
     --max-states 64
 
+To evaluate symbol delay, you should:
+(1) Generate cuts with word-time alignments:
+  ./local/add_alignment_librispeech.py \
+    --alignments-dir data/alignment \
+    --cuts-in-dir data/fbank \
+    --cuts-out-dir data/fbank_ali
+(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
+For example:
+  ./pruned_transducer_stateless4/decode.py \
+    --epoch 40 \
+    --avg 20 \
+    --exp-dir ./pruned_transducer_stateless4/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search \
+    --manifest-dir data/fbank_ali
+
 """
 
 
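Before decoding, it can be worth spot-checking that step (1) really attached word alignments to the cuts. A minimal sketch (the manifest filename below is an assumption and may differ in your setup; the access pattern mirrors the code added in decode_dataset):

# Hedged sketch: inspect one cut from the alignment-enriched manifests.
from lhotse import load_manifest_lazy

# assumed filename; adjust to whatever lives in data/fbank_ali
cuts = load_manifest_lazy("data/fbank_ali/librispeech_cuts_test-clean.jsonl.gz")
cut = next(iter(cuts))
for s in cut.supervisions:
    if s.alignment is not None and "word" in s.alignment:
        # each item carries the word symbol and its start time in seconds
        print([(w.symbol, w.start) for w in s.alignment["word"]][:5])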
@@ -142,10 +158,12 @@ from icefall.checkpoint import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    DecodingResults,
+    parse_hyp_and_timestamp,
     setup_logger,
-    store_transcripts,
+    store_transcripts_and_timestamps,
     str2bool,
-    write_error_stats,
+    write_error_stats_with_timestamps,
 )
 
 LOG_EPS = math.log(1e-10)
@@ -318,7 +336,7 @@ def get_parser():
         "--left-context",
         type=int,
         default=64,
-        help="left context can be seen during decoding (in frames after subsampling)",
+        help="left context can be seen during decoding (in frames after subsampling)",  # noqa
     )
 
     parser.add_argument(
@@ -350,7 +368,7 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[str]]]:
+) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
 
@@ -358,9 +376,10 @@ def decode_one_batch(
           if greedy_search is used, it would be "greedy_search"
           If beam search with a beam size of 7 is used, it would be
           "beam_7"
-        - value: It contains the decoding result. `len(value)` equals to
-          batch size. `value[i]` is the decoding result for the i-th
-          utterance in the given batch.
+        - value: It is a tuple. `len(value[0])` and `len(value[1])` are both
+          equal to the batch size. `value[0][i]` and `value[1][i]`
+          are the decoding result and timestamps for the i-th utterance
+          in the given batch respectively.
     Args:
       params:
         It's the return value of :func:`get_params`.
@@ -379,8 +398,8 @@ def decode_one_batch(
         only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
-      Return the decoding result. See above description for the format of
-      the returned dict.
+      Return the decoding result and timestamps. See above description for the
+      format of the returned dict.
     """
     device = next(model.parameters()).device
     feature = batch["inputs"]
@@ -412,10 +431,8 @@ def decode_one_batch(
         x=feature, x_lens=feature_lens
     )
 
-    hyps = []
-
     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search_one_best(
+        res = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -423,11 +440,10 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_nbest_LG":
-        hyp_tokens = fast_beam_search_nbest_LG(
+        res = fast_beam_search_nbest_LG(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -437,11 +453,10 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            return_timestamps=True,
         )
-        for hyp in hyp_tokens:
-            hyps.append([word_table[i] for i in hyp])
     elif params.decoding_method == "fast_beam_search_nbest":
-        hyp_tokens = fast_beam_search_nbest(
+        res = fast_beam_search_nbest(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -451,11 +466,10 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_nbest_oracle":
-        hyp_tokens = fast_beam_search_nbest_oracle(
+        res = fast_beam_search_nbest_oracle(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -466,56 +480,67 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=sp.encode(supervisions["text"]),
             nbest_scale=params.nbest_scale,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
     ):
-        hyp_tokens = greedy_search_batch(
+        res = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search":
-        hyp_tokens = modified_beam_search(
+        res = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
+        tokens = []
+        timestamps = []
         for i in range(batch_size):
            # fmt: off
            encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
            # fmt: on
            if params.decoding_method == "greedy_search":
-                hyp = greedy_search(
+                res = greedy_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     max_sym_per_frame=params.max_sym_per_frame,
+                    return_timestamps=True,
                 )
            elif params.decoding_method == "beam_search":
-                hyp = beam_search(
+                res = beam_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
+                    return_timestamps=True,
                 )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.decoding_method}"
                )
-            hyps.append(sp.decode(hyp).split())
+            tokens.extend(res.tokens)
+            timestamps.extend(res.timestamps)
+        res = DecodingResults(tokens=tokens, timestamps=timestamps)
 
+    hyps, timestamps = parse_hyp_and_timestamp(
+        decoding_method=params.decoding_method,
+        res=res,
+        sp=sp,
+        subsampling_factor=params.subsampling_factor,
+        frame_shift_ms=params.frame_shift_ms,
+        word_table=word_table,
+    )
 
     if params.decoding_method == "greedy_search":
-        return {"greedy_search": hyps}
+        return {"greedy_search": (hyps, timestamps)}
     elif "fast_beam_search" in params.decoding_method:
         key = f"beam_{params.beam}_"
         key += f"max_contexts_{params.max_contexts}_"
@@ -526,9 +551,9 @@ def decode_one_batch(
         if "LG" in params.decoding_method:
             key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
 
-        return {key: hyps}
+        return {key: (hyps, timestamps)}
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
 
 
 def decode_dataset(
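In effect, every decoding path now funnels through parse_hyp_and_timestamp, so each entry of the dict returned by decode_one_batch is a (hyps, timestamps) pair. A sketch of consuming it (illustrative; `params`, `model`, `sp`, and `batch` are assumed to already exist):

# Hedged sketch of the new return value of decode_one_batch().
hyps_dict = decode_one_batch(params=params, model=model, sp=sp, batch=batch)
for name, (hyps, timestamps) in hyps_dict.items():
    # hyps[i] is a list of words; timestamps[i] holds word start times in seconds
    print(name, hyps[0], ["%.2f" % t for t in timestamps[0]])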
@@ -538,7 +563,9 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+) -> Dict[
+    str, List[Tuple[str, List[str], List[str], List[float], List[float]]]
+]:
     """Decode dataset.
 
     Args:
@@ -559,9 +586,12 @@ def decode_dataset(
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
-      Its value is a list of tuples. Each tuple contains two elements:
-      The first is the reference transcript, and the second is the
-      predicted result.
+      Its value is a list of tuples. Each tuple contains five elements:
+        - cut_id
+        - reference transcript
+        - predicted result
+        - timestamp of reference transcript
+        - timestamp of predicted result
     """
     num_cuts = 0
 
@@ -580,6 +610,18 @@ def decode_dataset(
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
+        timestamps_ref = []
+        for cut in batch["supervisions"]["cut"]:
+            for s in cut.supervisions:
+                time = []
+                if s.alignment is not None and "word" in s.alignment:
+                    time = [
+                        aliword.start
+                        for aliword in s.alignment["word"]
+                        if aliword.symbol != ""
+                    ]
+                timestamps_ref.append(time)
+
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
@@ -589,12 +631,18 @@ def decode_dataset(
             batch=batch,
         )
 
-        for name, hyps in hyps_dict.items():
+        for name, (hyps, timestamps_hyp) in hyps_dict.items():
             this_batch = []
-            assert len(hyps) == len(texts)
-            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+            assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
+                timestamps_ref
+            )
+            for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
+                cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
+            ):
                 ref_words = ref_text.split()
-                this_batch.append((cut_id, ref_words, hyp_words))
+                this_batch.append(
+                    (cut_id, ref_words, hyp_words, time_ref, time_hyp)
+                )
 
             results[name].extend(this_batch)
 
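The reference timestamps come straight from the word-level alignments attached in step (1) of the usage notes above. A standalone sketch of the same extraction with mock data, so no lhotse objects are needed (the namedtuple below is a stand-in for lhotse's alignment items, which expose symbol and start):

# Hedged sketch: reference word start times from a word alignment.
from collections import namedtuple

AliWord = namedtuple("AliWord", ["symbol", "start"])
alignment = {
    "word": [AliWord("HELLO", 0.48), AliWord("", 0.90), AliWord("WORLD", 1.02)]
}

# silences are stored with an empty symbol and must be skipped
time = [w.start for w in alignment["word"] if w.symbol != ""]
print(time)  # [0.48, 1.02]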
@@ -612,15 +660,19 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    results_dict: Dict[
+        str,
+        List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
+    ],
 ):
     test_set_wers = dict()
+    test_set_delays = dict()
     for key, results in results_dict.items():
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        store_transcripts_and_timestamps(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
@@ -629,10 +681,11 @@ def save_results(
             params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
-            wer = write_error_stats(
+            wer, mean_delay, var_delay = write_error_stats_with_timestamps(
                 f, f"{test_set_name}-{key}", results, enable_log=True
             )
             test_set_wers[key] = wer
+            test_set_delays[key] = (mean_delay, var_delay)
 
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
@@ -646,6 +699,19 @@ def save_results(
         for key, val in test_set_wers:
             print("{}\t{}".format(key, val), file=f)
 
+    test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
+    delays_info = (
+        params.res_dir
+        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(delays_info, "w") as f:
+        print("settings\tsymbol-delay", file=f)
+        for key, val in test_set_delays:
+            print(
+                "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
+                file=f,
+            )
+
     s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key, val in test_set_wers:
@@ -653,6 +719,15 @@ def save_results(
             note = ""
     logging.info(s)
 
+    s = "\nFor {}, symbol-delay of different settings are:\n".format(
+        test_set_name
+    )
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_delays:
+        s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
+        note = ""
+    logging.info(s)
+
 
 @torch.no_grad()
 def main():
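Given those format strings, a symbol-delay summary file would look roughly like the following (the key and the numbers here are hypothetical, purely to illustrate the layout):

settings	symbol-delay
greedy_search	mean: 0.142s, variance: 0.004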
@@ -386,6 +386,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
+            "frame_shift_ms": 10.0,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
446
icefall/utils.py
@@ -24,9 +24,10 @@ import re
 import subprocess
 from collections import defaultdict
 from contextlib import contextmanager
+from dataclasses import dataclass
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Iterable, List, TextIO, Tuple, Union
+from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
 
 import k2
 import k2.version
@@ -248,6 +249,86 @@ def get_texts(
     return aux_labels.tolist()
 
 
+@dataclass
+class DecodingResults:
+    # Decoded token IDs for each utterance in the batch
+    tokens: List[List[int]]
+
+    # timestamps[i][k] contains the frame number on which tokens[i][k]
+    # is decoded
+    timestamps: List[List[int]]
+
+    # hyps[i] is the recognition results, i.e., word IDs
+    # for the i-th utterance with fast_beam_search_nbest_LG.
+    hyps: Union[List[List[int]], k2.RaggedTensor] = None
+
+
+def get_tokens_and_timestamps(labels: List[int]) -> Tuple[List[int], List[int]]:
+    tokens = []
+    timestamps = []
+    for i, v in enumerate(labels):
+        if v != 0:
+            tokens.append(v)
+            timestamps.append(i)
+
+    return tokens, timestamps
+
+
+def get_texts_with_timestamp(
+    best_paths: k2.Fsa, return_ragged: bool = False
+) -> DecodingResults:
+    """Extract the texts (as word IDs) and timestamps from the best-path FSAs.
+    Args:
+      best_paths:
+        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
+        containing multiple FSAs, which is expected to be the result
+        of k2.shortest_path (otherwise the returned values won't
+        be meaningful).
+      return_ragged:
+        True to return a ragged tensor with two axes [utt][word_id].
+        False to return a list-of-list word IDs.
+    Returns:
+      Returns a list of lists of int, containing the label sequences we
+      decoded.
+    """
+    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
+        # remove 0's and -1's.
+        aux_labels = best_paths.aux_labels.remove_values_leq(0)
+        # TODO: change arcs.shape() to arcs.shape
+        aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
+
+        # remove the states and arcs axes.
+        aux_shape = aux_shape.remove_axis(1)
+        aux_shape = aux_shape.remove_axis(1)
+        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
+    else:
+        # remove axis corresponding to states.
+        aux_shape = best_paths.arcs.shape().remove_axis(1)
+        aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
+        # remove 0's and -1's.
+        aux_labels = aux_labels.remove_values_leq(0)
+
+    assert aux_labels.num_axes == 2
+
+    labels_shape = best_paths.arcs.shape().remove_axis(1)
+    labels_list = k2.RaggedTensor(
+        labels_shape, best_paths.labels.contiguous()
+    ).tolist()
+
+    tokens = []
+    timestamps = []
+    for labels in labels_list:
+        token, time = get_tokens_and_timestamps(labels[:-1])
+        tokens.append(token)
+        timestamps.append(time)
+
+    return DecodingResults(
+        tokens=tokens,
+        timestamps=timestamps,
+        hyps=aux_labels if return_ragged else aux_labels.tolist(),
+    )
+
+
 def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
     """Extract labels or aux_labels from the best-path FSAs.
 
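get_tokens_and_timestamps is the heart of the frame bookkeeping: a best path carries one label per frame, with 0 acting as the blank/epsilon symbol, so the index of each non-zero label is the frame on which it was emitted. A worked example:

# Worked example for get_tokens_and_timestamps.
labels = [0, 0, 25, 0, 0, 7, 0, 93]
tokens, timestamps = get_tokens_and_timestamps(labels)
print(tokens)      # [25, 7, 93]
print(timestamps)  # [2, 5, 7]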
@@ -352,6 +433,33 @@ def store_transcripts(
         print(f"{cut_id}:\thyp={hyp}", file=f)
 
 
+def store_transcripts_and_timestamps(
+    filename: Pathlike,
+    texts: Iterable[Tuple[str, List[str], List[str], List[float], List[float]]],
+) -> None:
+    """Save predicted results and reference transcripts as well as their timestamps
+    to a file.
+
+    Args:
+      filename:
+        File to save the results to.
+      texts:
+        An iterable of tuples. The first element is the cur_id, the second is
+        the reference transcript and the third element is the predicted result.
+    Returns:
+      Return None.
+    """
+    with open(filename, "w") as f:
+        for cut_id, ref, hyp, time_ref, time_hyp in texts:
+            print(f"{cut_id}:\tref={ref}", file=f)
+            print(f"{cut_id}:\thyp={hyp}", file=f)
+            if len(time_ref) > 0:
+                s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]"
+                print(f"{cut_id}:\ttimestamp_ref={s}", file=f)
+            s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
+            print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
+
+
 def write_error_stats(
     f: TextIO,
     test_set_name: str,
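Each utterance therefore occupies up to four lines in the recogs file. With hypothetical values (the cut id and times below are illustrative only), one entry would look like:

cut-1001:	ref=['HELLO', 'WORLD']
cut-1001:	hyp=['HELLO', 'WORLD']
cut-1001:	timestamp_ref=[0.480, 1.020]
cut-1001:	timestamp_hyp=[0.520, 1.080]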
@@ -519,6 +627,211 @@ def write_error_stats(
     return float(tot_err_rate)
 
 
+def write_error_stats_with_timestamps(
+    f: TextIO,
+    test_set_name: str,
+    results: List[Tuple[str, List[str], List[str], List[float], List[float]]],
+    enable_log: bool = True,
+) -> Tuple[float, float, float]:
+    """Write statistics based on predicted results and reference transcripts
+    as well as their timestamps.
+
+    It will write the following to the given file:
+
+        - WER
+        - number of insertions, deletions, substitutions, corrects and total
+          reference words. For example::
+
+              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
+              reference words (2337 correct)
+
+        - The difference between the reference transcript and predicted result.
+          An instance is given below::
+
+              THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
+
+          The above example shows that the reference word is `EDISON`,
+          but it is predicted to `ADDISON` (a substitution error).
+
+          Another example is::
+
+              FOR THE FIRST DAY (SIR->*) I THINK
+
+          The reference word `SIR` is missing in the predicted
+          results (a deletion error).
+      results:
+        An iterable of tuples. The first element is the cur_id, the second is
+        the reference transcript and the third element is the predicted result.
+      enable_log:
+        If True, also print detailed WER to the console.
+        Otherwise, it is written only to the given file.
+
+    Returns:
+      Return total word error rate and mean delay.
+    """
+    subs: Dict[Tuple[str, str], int] = defaultdict(int)
+    ins: Dict[str, int] = defaultdict(int)
+    dels: Dict[str, int] = defaultdict(int)
+
+    # `words` stores counts per word, as follows:
+    #   corr, ref_sub, hyp_sub, ins, dels
+    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
+    num_corr = 0
+    ERR = "*"
+    # Compute mean alignment delay on the correct words
+    all_delay = []
+    for cut_id, ref, hyp, time_ref, time_hyp in results:
+        ali = kaldialign.align(ref, hyp, ERR)
+        has_time_ref = len(time_ref) > 0
+        if has_time_ref:
+            # pointer to timestamp_hyp
+            p_hyp = 0
+            # pointer to timestamp_ref
+            p_ref = 0
+        for ref_word, hyp_word in ali:
+            if ref_word == ERR:
+                ins[hyp_word] += 1
+                words[hyp_word][3] += 1
+                if has_time_ref:
+                    p_hyp += 1
+            elif hyp_word == ERR:
+                dels[ref_word] += 1
+                words[ref_word][4] += 1
+                if has_time_ref:
+                    p_ref += 1
+            elif hyp_word != ref_word:
+                subs[(ref_word, hyp_word)] += 1
+                words[ref_word][1] += 1
+                words[hyp_word][2] += 1
+                if has_time_ref:
+                    p_hyp += 1
+                    p_ref += 1
+            else:
+                words[ref_word][0] += 1
+                num_corr += 1
+                if has_time_ref:
+                    all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
+                    p_hyp += 1
+                    p_ref += 1
+        if has_time_ref:
+            assert p_hyp == len(hyp), (p_hyp, len(hyp))
+            assert p_ref == len(ref), (p_ref, len(ref))
+
+    ref_len = sum([len(r) for _, r, _, _, _ in results])
+    sub_errs = sum(subs.values())
+    ins_errs = sum(ins.values())
+    del_errs = sum(dels.values())
+    tot_errs = sub_errs + ins_errs + del_errs
+    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
+
+    mean_delay = "inf"
+    var_delay = "inf"
+    num_delay = len(all_delay)
+    if num_delay > 0:
+        mean_delay = sum(all_delay) / num_delay
+        var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
+        mean_delay = "%.3f" % mean_delay
+        var_delay = "%.3f" % var_delay
+
+    if enable_log:
+        logging.info(
+            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
+            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
+            f"{del_errs} del, {sub_errs} sub ]"
+        )
+        logging.info(
+            f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} "  # noqa
+            f"computed on {num_delay} correct words"
+        )
+
+    print(f"%WER = {tot_err_rate}", file=f)
+    print(
+        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
+        f"{sub_errs} substitutions, over {ref_len} reference "
+        f"words ({num_corr} correct)",
+        file=f,
+    )
+    print(
+        "Search below for sections starting with PER-UTT DETAILS:, "
+        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
+        file=f,
+    )
+
+    print("", file=f)
+    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
+    for cut_id, ref, hyp, _, _ in results:
+        ali = kaldialign.align(ref, hyp, ERR)
+        combine_successive_errors = True
+        if combine_successive_errors:
+            ali = [[[x], [y]] for x, y in ali]
+            for i in range(len(ali) - 1):
+                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
+                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
+                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
+                    ali[i] = [[], []]
+            ali = [
+                [
+                    list(filter(lambda a: a != ERR, x)),
+                    list(filter(lambda a: a != ERR, y)),
+                ]
+                for x, y in ali
+            ]
+            ali = list(filter(lambda x: x != [[], []], ali))
+            ali = [
+                [
+                    ERR if x == [] else " ".join(x),
+                    ERR if y == [] else " ".join(y),
+                ]
+                for x, y in ali
+            ]
+
+        print(
+            f"{cut_id}:\t"
+            + " ".join(
+                (
+                    ref_word
+                    if ref_word == hyp_word
+                    else f"({ref_word}->{hyp_word})"
+                    for ref_word, hyp_word in ali
+                )
+            ),
+            file=f,
+        )
+
+    print("", file=f)
+    print("SUBSTITUTIONS: count ref -> hyp", file=f)
+
+    for count, (ref, hyp) in sorted(
+        [(v, k) for k, v in subs.items()], reverse=True
+    ):
+        print(f"{count} {ref} -> {hyp}", file=f)
+
+    print("", file=f)
+    print("DELETIONS: count ref", file=f)
+    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
+        print(f"{count} {ref}", file=f)
+
+    print("", file=f)
+    print("INSERTIONS: count hyp", file=f)
+    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
+        print(f"{count} {hyp}", file=f)
+
+    print("", file=f)
+    print(
+        "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
+    )
+    for _, word, counts in sorted(
+        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
+    ):
+        (corr, ref_sub, hyp_sub, ins, dels) = counts
+        tot_errs = ref_sub + hyp_sub + ins + dels
+        ref_count = corr + ref_sub + dels
+        hyp_count = corr + hyp_sub + ins
+
+        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
+    return float(tot_err_rate), float(mean_delay), float(var_delay)
+
+
 class MetricsTracker(collections.defaultdict):
     def __init__(self):
         # Passing the type 'int' to the base-class constructor
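The delay accounting rides on the same edit alignment used for WER: two pointers walk the reference and hypothesis timestamp lists, and only alignment rows where the words agree contribute a delay sample. A condensed, self-contained sketch of just that part, on mock inputs (kaldialign is already a dependency of write_error_stats):

# Hedged sketch: mean/variance of symbol delay over correct words only.
import kaldialign

ref = ["A", "B", "C"]
hyp = ["A", "C"]                 # "B" was deleted
time_ref = [0.10, 0.50, 0.90]
time_hyp = [0.18, 1.02]

ERR = "*"
delays = []
p_hyp = p_ref = 0
for ref_word, hyp_word in kaldialign.align(ref, hyp, ERR):
    if ref_word == ERR:          # insertion: only the hyp pointer advances
        p_hyp += 1
    elif hyp_word == ERR:        # deletion: only the ref pointer advances
        p_ref += 1
    else:
        if ref_word == hyp_word:  # correct word: record its delay
            delays.append(time_hyp[p_hyp] - time_ref[p_ref])
        p_hyp += 1               # substitutions advance both, record nothing
        p_ref += 1

mean = sum(delays) / len(delays)
var = sum((d - mean) ** 2 for d in delays) / len(delays)
print("%.3f" % mean, "%.3f" % var)  # 0.100 0.000 (delays are [0.08, 0.12])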
@@ -978,6 +1291,137 @@ def display_and_save_batch(
     logging.info(f"num tokens: {num_tokens}")
 
 
+def convert_timestamp(
+    frames: List[int],
+    subsampling_factor: int,
+    frame_shift_ms: float = 10,
+) -> List[float]:
+    """Convert frame numbers to time (in seconds) given subsampling factor
+    and frame shift (in milliseconds).
+
+    Args:
+      frames:
+        A list of frame numbers after subsampling.
+      subsampling_factor:
+        The subsampling factor of the model.
+      frame_shift_ms:
+        Frame shift in milliseconds between two contiguous frames.
+    Return:
+      Return the time in seconds corresponding to each given frame.
+    """
+    frame_shift = frame_shift_ms / 1000.0
+    time = []
+    for f in frames:
+        time.append(f * subsampling_factor * frame_shift)
+
+    return time
+
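A quick sanity check of the arithmetic: with the usual 10 ms frame shift and a subsampling factor of 4, frame index k maps to k * 4 * 0.01 seconds.

# Worked example for convert_timestamp.
print(convert_timestamp([2, 5, 8], subsampling_factor=4, frame_shift_ms=10))
# -> [0.08, 0.2, 0.32] (up to float rounding)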
+
+def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
+    """
+    Parse timestamp of each word.
+
+    Args:
+      tokens:
+        List of tokens.
+      timestamp:
+        List of timestamp of each token.
+
+    Returns:
+      List of timestamp of each word.
+    """
+    start_token = b"\xe2\x96\x81".decode()  # '▁'
+    assert len(tokens) == len(timestamp)
+    ans = []
+    for i in range(len(tokens)):
+        flag = False
+        if i == 0 or tokens[i].startswith(start_token):
+            flag = True
+            if len(tokens[i]) == 1 and tokens[i].startswith(start_token):
+                # tokens[i] == start_token
+                if i == len(tokens) - 1:
+                    # it is the last token
+                    flag = False
+                elif tokens[i + 1].startswith(start_token):
+                    # the next token also starts with start_token
+                    flag = False
+        if flag:
+            ans.append(timestamp[i])
+    return ans
+
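parse_timestamp keeps one timestamp per word: the time of the first BPE piece of each word (pieces beginning with '▁' start a new word, and a lone '▁' piece is skipped when the word actually begins on the following piece). For example:

# Worked example for parse_timestamp with BPE pieces.
tokens = ["▁HE", "LLO", "▁WORLD"]
times = [0.08, 0.20, 0.48]
print(parse_timestamp(tokens, times))  # [0.08, 0.48]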
+
+def parse_hyp_and_timestamp(
+    res: DecodingResults,
+    decoding_method: str,
+    sp: spm.SentencePieceProcessor,
+    subsampling_factor: int,
+    frame_shift_ms: float = 10,
+    word_table: Optional[k2.SymbolTable] = None,
+) -> Tuple[List[List[str]], List[List[float]]]:
+    """Parse hypothesis and timestamp.
+
+    Args:
+      res:
+        A DecodingResults object.
+      decoding_method:
+        Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+          - fast_beam_search_nbest
+          - fast_beam_search_nbest_oracle
+          - fast_beam_search_nbest_LG
+      sp:
+        The BPE model.
+      subsampling_factor:
+        The integer subsampling factor.
+      frame_shift_ms:
+        The float frame shift used for feature extraction.
+      word_table:
+        The word symbol table.
+
+    Returns:
+      Return a list of hypothesis and timestamp.
+    """
+    assert decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
+        "modified_beam_search",
+    )
+
+    hyps = []
+    timestamps = []
+
+    N = len(res.tokens)
+    assert len(res.timestamps) == N
+    use_word_table = False
+    if decoding_method == "fast_beam_search_nbest_LG":
+        assert word_table is not None
+        use_word_table = True
+
+    for i in range(N):
+        tokens = sp.id_to_piece(res.tokens[i])
+        if use_word_table:
+            words = [word_table[i] for i in res.hyps[i]]
+        else:
+            words = sp.decode_pieces(tokens).split()
+        time = convert_timestamp(
+            res.timestamps[i], subsampling_factor, frame_shift_ms
+        )
+        time = parse_timestamp(tokens, time)
+        assert len(time) == len(words), (tokens, words)
+
+        hyps.append(words)
+        timestamps.append(time)
+
+    return hyps, timestamps
+
+
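Putting the pieces together, the typical call chain is DecodingResults -> convert_timestamp -> parse_timestamp -> (words, seconds), all wrapped by parse_hyp_and_timestamp. A sketch of a call (illustrative; `res` comes from one of the search functions above, `sp` is a loaded SentencePiece model, and the subsampling factor of 4 is an assumption about the model):

# Hedged sketch of the end-to-end conversion from frames to word times.
hyps, timestamps = parse_hyp_and_timestamp(
    res=res,
    decoding_method="greedy_search",
    sp=sp,
    subsampling_factor=4,  # assumption: Conformer-style 4x subsampling
    frame_shift_ms=10,
)
for words, times in zip(hyps, timestamps):
    for w, t in zip(words, times):
        print("%.2f\t%s" % (t, w))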
 # `is_module_available` is copied from
 # https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9
 def is_module_available(*modules: str) -> bool: