Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-13 12:02:21 +00:00)

Merge branch 'k2-fsa:master' into master

Commit 3b142dbef3

@@ -13,9 +13,12 @@ cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-04-29

log "Downloading pre-trained model from $repo_url"
-git lfs install
-git clone $repo_url
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained-epoch-25-avg-6.pt"
+popd

log "Display test files"
tree $repo/
@@ -52,6 +52,9 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then

  if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then
    git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm
+   pushd $dl_dir/lm
+   git lfs pull --include "3-gram.unpruned.arpa"
+   popd
  fi
fi
egs/csj/ASR/.gitignore (vendored, 3 changes)

@@ -1,7 +1,8 @@
-librispeech_*.*
+librispeech_*
todelete*
lang*
notify_tg.py
finetune_*
misc.ini
.vscode/*
+offline/*
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/compute_fbank_musan.py

egs/csj/ASR/local/compute_fbank_musan.py (new file, 127 lines)

@@ -0,0 +1,127 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor


ARGPARSE_DESCRIPTION = """
This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/fbank.
"""

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
    # src_dir = Path("data/manifests")
    # output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
    num_mel_bins = 80

    dataset_parts = (
        "music",
        "speech",
        "noise",
    )
    prefix = "musan"
    suffix = "jsonl.gz"
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=manifest_dir,
        prefix=prefix,
        suffix=suffix,
    )
    assert manifests is not None

    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )

    musan_cuts_path = fbank_dir / "musan_cuts.jsonl.gz"

    if musan_cuts_path.is_file():
        logging.info(f"{musan_cuts_path} already exists - skipping")
        return

    logging.info("Extracting features for Musan")

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        # create chunks of Musan with duration 5 - 10 seconds
        musan_cuts = (
            CutSet.from_manifests(
                recordings=combine(
                    part["recordings"] for part in manifests.values()
                )
            )
            .cut_into_windows(10.0)
            .filter(lambda c: c.duration > 5)
            .compute_and_store_features(
                extractor=extractor,
                storage_path=f"{fbank_dir}/musan_feats",
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                storage_type=LilcomChunkyWriter,
            )
        )
        musan_cuts.to_file(musan_cuts_path)


def get_args():
    parser = argparse.ArgumentParser(
        description=ARGPARSE_DESCRIPTION,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--manifest-dir", type=Path, help="Path to save manifests"
    )
    parser.add_argument(
        "--fbank-dir", type=Path, help="Path to save fbank features"
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    compute_fbank_musan(args.manifest_dir, args.fbank_dir)
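A minimal invocation sketch for the new script above (directory names are the recipe's defaults, not something this diff enforces): it reads the MUSAN manifests, typically produced by lhotse's MUSAN preparation step, and writes musan_cuts.jsonl.gz plus the feature archive under the given fbank dir.

# Sketch only; assumes MUSAN manifests already exist under data/manifests.
python3 ./local/compute_fbank_musan.py \
    --manifest-dir data/manifests \
    --fbank-dir data/fbank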
@@ -64,9 +64,9 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  # Example: lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -c local/conf/disfluent.ini
  # NOTE: In case multiple config files are supplied, the second config file and onwards will inherit
  # the segment boundaries of the first config file.
- if [ ! -e $csj_manifest_dir/.librispeech.done ]; then
+ if [ ! -e $csj_manifest_dir/.csj.done ]; then
    lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -j 4
-   touch $csj_manifest_dir/.librispeech.done
+   touch $csj_manifest_dir/.csj.done
  fi
fi
@@ -1304,7 +1304,7 @@ results at:

##### Baseline-2

-It has 88.98 M parameters. Compared to the model in pruned_transducer_stateless2, its has more
+It has 87.8 M parameters. Compared to the model in pruned_transducer_stateless2, it has more
layers (24 vs 12) but a narrower model (1536 feedforward dim and 384 encoder dim vs 2048 feedforward dim and 512 encoder dim).

| | test-clean | test-other | comment |
egs/librispeech/ASR/add_alignments.sh (new executable file, 12 lines)

@@ -0,0 +1,12 @@
#!/usr/bin/env bash

set -eou pipefail

alignments_dir=data/alignment
cuts_in_dir=data/fbank
cuts_out_dir=data/fbank_ali

python3 ./local/add_alignment_librispeech.py \
    --alignments-dir $alignments_dir \
    --cuts-in-dir $cuts_in_dir \
    --cuts-out-dir $cuts_out_dir
egs/librispeech/ASR/generate-lm.sh (new executable file, 20 lines)

@@ -0,0 +1,20 @@
#!/usr/bin/env bash

lang_dir=data/lang_bpe_500

for ngram in 2 3 5; do
  if [ ! -f $lang_dir/${ngram}gram.arpa ]; then
    ./shared/make_kn_lm.py \
      -ngram-order ${ngram} \
      -text $lang_dir/transcript_tokens.txt \
      -lm $lang_dir/${ngram}gram.arpa
  fi

  if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then
    python3 -m kaldilm \
      --read-symbol-table="$lang_dir/tokens.txt" \
      --disambig-symbol='#0' \
      --max-order=${ngram} \
      $lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt
  fi
done
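A usage sketch (assumption: the BPE lang dir already contains transcript_tokens.txt and tokens.txt, which generate-lm.sh reads but does not create). The ${ngram}gram.fst.txt files it writes are what the n-gram rescoring decoding method added later in this commit loads via NgramLm.

# Sketch only; run from egs/librispeech/ASR with data/lang_bpe_500 prepared.
./generate-lm.sh
ls data/lang_bpe_500/3gram.arpa data/lang_bpe_500/3gram.fst.txt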
egs/librispeech/ASR/local/add_alignment_librispeech.py (new executable file, 196 lines)

@@ -0,0 +1,196 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file adds alignments from https://github.com/CorentinJ/librispeech-alignments  # noqa
to the existing fbank features dir (e.g., data/fbank)
and saves cuts to a new dir (e.g., data/fbank_ali).
"""

import argparse
import logging
import zipfile
from pathlib import Path
from typing import List

from lhotse import CutSet, load_manifest_lazy
from lhotse.recipes.librispeech import parse_alignments
from lhotse.utils import is_module_available

LIBRISPEECH_ALIGNMENTS_URL = (
    "https://drive.google.com/uc?id=1WYfgr31T-PPwMcxuAq09XZfHQO5Mw8fE"
)

DATASET_PARTS = [
    "dev-clean",
    "dev-other",
    "test-clean",
    "test-other",
    "train-clean-100",
    "train-clean-360",
    "train-other-500",
]


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--alignments-dir",
        type=str,
        default="data/alignment",
        help="The dir to save alignments.",
    )

    parser.add_argument(
        "--cuts-in-dir",
        type=str,
        default="data/fbank",
        help="The dir of the existing cuts without alignments.",
    )

    parser.add_argument(
        "--cuts-out-dir",
        type=str,
        default="data/fbank_ali",
        help="The dir to save the new cuts with alignments",
    )

    return parser


def download_alignments(
    target_dir: str, alignments_url: str = LIBRISPEECH_ALIGNMENTS_URL
):
    """
    Download and extract the alignments.

    Note: If you can not access drive.google.com, you could download the file
    `LibriSpeech-Alignments.zip` from huggingface:
    https://huggingface.co/Zengwei/librispeech-alignments
    and extract the zip file manually.

    Args:
      target_dir:
        The dir to save alignments.
      alignments_url:
        The URL of alignments.
    """
    """Modified from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/librispeech.py"""  # noqa
    target_dir = Path(target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    completed_detector = target_dir / ".ali_completed"
    if completed_detector.is_file():
        logging.info("The alignment files already exist.")
        return

    ali_zip_path = target_dir / "LibriSpeech-Alignments.zip"
    if not ali_zip_path.is_file():
        assert is_module_available(
            "gdown"
        ), 'To download LibriSpeech alignments, please install "pip install gdown"'  # noqa
        import gdown

        gdown.download(alignments_url, output=str(ali_zip_path))

    with zipfile.ZipFile(str(ali_zip_path)) as f:
        f.extractall(path=target_dir)
    completed_detector.touch()


def add_alignment(
    alignments_dir: str,
    cuts_in_dir: str = "data/fbank",
    cuts_out_dir: str = "data/fbank_ali",
    dataset_parts: List[str] = DATASET_PARTS,
):
    """
    Add alignment info to existing cuts.

    Args:
      alignments_dir:
        The dir of the alignments.
      cuts_in_dir:
        The dir of the existing cuts.
      cuts_out_dir:
        The dir to save the new cuts with alignments.
      dataset_parts:
        Librispeech parts to add alignments.
    """
    alignments_dir = Path(alignments_dir)
    cuts_in_dir = Path(cuts_in_dir)
    cuts_out_dir = Path(cuts_out_dir)
    cuts_out_dir.mkdir(parents=True, exist_ok=True)

    for part in dataset_parts:
        logging.info(f"Processing {part}")

        cuts_in_path = cuts_in_dir / f"librispeech_cuts_{part}.jsonl.gz"
        if not cuts_in_path.is_file():
            logging.info(f"{cuts_in_path} does not exist - skipping.")
            continue
        cuts_out_path = cuts_out_dir / f"librispeech_cuts_{part}.jsonl.gz"
        if cuts_out_path.is_file():
            logging.info(f"{part} already exists - skipping.")
            continue

        # parse alignments
        alignments = {}
        part_ali_dir = alignments_dir / "LibriSpeech" / part
        for ali_path in part_ali_dir.rglob("*.alignment.txt"):
            ali = parse_alignments(ali_path)
            alignments.update(ali)
        logging.info(
            f"{part} has {len(alignments.keys())} cuts with alignments."
        )

        # add alignment attribute and write out
        cuts_in = load_manifest_lazy(cuts_in_path)
        with CutSet.open_writer(cuts_out_path) as writer:
            for cut in cuts_in:
                for idx, subcut in enumerate(cut.supervisions):
                    origin_id = subcut.id.split("_")[0]
                    if origin_id in alignments:
                        ali = alignments[origin_id]
                    else:
                        logging.info(
                            f"Warning: {origin_id} does not have alignment."
                        )
                        ali = []
                    subcut.alignment = {"word": ali}
                writer.write(cut, flush=True)


def main():
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)

    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    download_alignments(args.alignments_dir)
    add_alignment(args.alignments_dir, args.cuts_in_dir, args.cuts_out_dir)


if __name__ == "__main__":
    main()
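A quick sanity-check sketch (the file-name pattern comes from the script above; the pipeline below is just one way to peek at a gzipped JSONL manifest): after running add_alignments.sh, each supervision in the new cuts should carry an "alignment" field with word-level start times.

# Sketch only; inspect the first cut of the test-clean manifest with alignments.
zcat data/fbank_ali/librispeech_cuts_test-clean.jsonl.gz | head -n 1 \
    | python3 -m json.tool | grep -A 5 '"alignment"'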
@@ -115,10 +115,12 @@ from beam_search import (
    greedy_search,
    greedy_search_batch,
    modified_beam_search,
+   modified_beam_search_ngram_rescoring,
)
from librispeech import LibriSpeech
from train import add_model_arguments, get_params, get_transducer_model

+from icefall import NgramLm
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,

@@ -214,6 +216,7 @@ def get_parser():
        - fast_beam_search_nbest
        - fast_beam_search_nbest_oracle
        - fast_beam_search_nbest_LG
+       - modified_beam_search_ngram_rescoring
        If you use fast_beam_search_nbest_LG, you have to specify
        `--lang-dir`, which should contain `LG.pt`.
        """,

@@ -303,6 +306,22 @@ def get_parser():
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

+   parser.add_argument(
+       "--tokens-ngram",
+       type=int,
+       default=3,
+       help="""Token Ngram used for rescoring.
+       Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+   )
+
+   parser.add_argument(
+       "--backoff-id",
+       type=int,
+       default=500,
+       help="""ID of the backoff symbol.
+       Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+   )
+
    add_model_arguments(parser)

    return parser

@@ -315,6 +334,8 @@ def decode_one_batch(
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
+   ngram_lm: Optional[NgramLm] = None,
+   ngram_lm_scale: float = 1.0,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

@@ -448,6 +469,17 @@ def decode_one_batch(
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
+   elif params.decoding_method == "modified_beam_search_ngram_rescoring":
+       hyp_tokens = modified_beam_search_ngram_rescoring(
+           model=model,
+           encoder_out=encoder_out,
+           encoder_out_lens=encoder_out_lens,
+           ngram_lm=ngram_lm,
+           ngram_lm_scale=ngram_lm_scale,
+           beam=params.beam_size,
+       )
+       for hyp in sp.decode(hyp_tokens):
+           hyps.append(hyp.split())
    else:
        batch_size = encoder_out.size(0)

@@ -497,6 +529,8 @@ def decode_dataset(
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
+   ngram_lm: Optional[NgramLm] = None,
+   ngram_lm_scale: float = 1.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

@@ -546,6 +580,8 @@ def decode_dataset(
            decoding_graph=decoding_graph,
            word_table=word_table,
            batch=batch,
+           ngram_lm=ngram_lm,
+           ngram_lm_scale=ngram_lm_scale,
        )

        for name, hyps in hyps_dict.items():

@@ -631,6 +667,7 @@ def main():
        "fast_beam_search_nbest_LG",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
+       "modified_beam_search_ngram_rescoring",
    )
    params.res_dir = params.exp_dir / params.decoding_method

@@ -655,6 +692,7 @@ def main():
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+   params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

@@ -768,6 +806,15 @@ def main():
    model.to(device)
    model.eval()

+   lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+   logging.info(f"lm filename: {lm_filename}")
+   ngram_lm = NgramLm(
+       str(params.lang_dir / lm_filename),
+       backoff_id=params.backoff_id,
+       is_binary=False,
+   )
+   logging.info(f"num states: {ngram_lm.lm.num_states}")
+
    if "fast_beam_search" in params.decoding_method:
        if params.decoding_method == "fast_beam_search_nbest_LG":
            lexicon = Lexicon(params.lang_dir)

@@ -812,6 +859,8 @@ def main():
            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
+           ngram_lm=ngram_lm,
+           ngram_lm_scale=params.ngram_lm_scale,
        )

        save_results(
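A decoding sketch for the new method (the recipe directory, epoch, avg, and scale values below are placeholders, not taken from this diff; --decoding-method, --tokens-ngram, --backoff-id, --lang-dir and the 3gram.fst.txt naming come from the hunks above and from generate-lm.sh):

# Sketch only; <recipe-dir> stands for the recipe this decode.py belongs to.
./<recipe-dir>/decode.py \
    --epoch 30 \
    --avg 10 \
    --exp-dir ./<recipe-dir>/exp \
    --lang-dir data/lang_bpe_500 \
    --max-duration 600 \
    --decoding-method modified_beam_search_ngram_rescoring \
    --tokens-ngram 3 \
    --backoff-id 500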
@@ -42,6 +42,11 @@ import argparse
import logging
from typing import List, Optional, Tuple

+from icefall import is_module_available
+
+if not is_module_available("onnxruntime"):
+    raise ValueError("Please 'pip install onnxruntime' first.")
+
import onnxruntime as ort
import sentencepiece as spm
import torch
@@ -91,6 +91,22 @@ Usage:
        --beam 20.0 \
        --max-contexts 8 \
        --max-states 64

+To evaluate symbol delay, you should:
+(1) Generate cuts with word-time alignments:
+./local/add_alignment_librispeech.py \
+    --alignments-dir data/alignment \
+    --cuts-in-dir data/fbank \
+    --cuts-out-dir data/fbank_ali
+(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
+For example:
+./lstm_transducer_stateless3/decode.py \
+    --epoch 40 \
+    --avg 20 \
+    --exp-dir ./lstm_transducer_stateless3/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search \
+    --manifest-dir data/fbank_ali
"""

@ -127,10 +143,12 @@ from icefall.checkpoint import (
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
DecodingResults,
|
||||
parse_hyp_and_timestamp,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
store_transcripts_and_timestamps,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
write_error_stats_with_timestamps,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
@ -314,7 +332,7 @@ def decode_one_batch(
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
@ -322,9 +340,11 @@ def decode_one_batch(
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
- value: It is a tuple. `len(value[0])` and `len(value[1])` are both
|
||||
equal to the batch size. `value[0][i]` and `value[1][i]`
|
||||
are the decoding result and timestamps for the i-th utterance
|
||||
in the given batch respectively.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
@ -343,8 +363,8 @@ def decode_one_batch(
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
Return the decoding result and timestamps. See above description for the
|
||||
format of the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
@ -370,10 +390,8 @@ def decode_one_batch(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
res = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -381,11 +399,10 @@ def decode_one_batch(
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
res = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -395,11 +412,10 @@ def decode_one_batch(
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
res = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -409,11 +425,10 @@ def decode_one_batch(
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
res = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -424,56 +439,67 @@ def decode_one_batch(
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
hyp_tokens = greedy_search_batch(
|
||||
res = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
res = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
tokens = []
|
||||
timestamps = []
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
res = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
return_timestamps=True,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
res = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
return_timestamps=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
tokens.extend(res.tokens)
|
||||
timestamps.extend(res.timestamps)
|
||||
res = DecodingResults(tokens=tokens, timestamps=timestamps)
|
||||
|
||||
hyps, timestamps = parse_hyp_and_timestamp(
|
||||
decoding_method=params.decoding_method,
|
||||
res=res,
|
||||
sp=sp,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
frame_shift_ms=params.frame_shift_ms,
|
||||
word_table=word_table,
|
||||
)
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
return {"greedy_search": (hyps, timestamps)}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
@ -484,9 +510,9 @@ def decode_one_batch(
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
return {key: (hyps, timestamps)}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -496,7 +522,9 @@ def decode_dataset(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
) -> Dict[
|
||||
str, List[Tuple[str, List[str], List[str], List[float], List[float]]]
|
||||
]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
@ -517,9 +545,12 @@ def decode_dataset(
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
Its value is a list of tuples. Each tuple contains five elements:
|
||||
- cut_id
|
||||
- reference transcript
|
||||
- predicted result
|
||||
- timestamp of reference transcript
|
||||
- timestamp of predicted result
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
@ -538,6 +569,18 @@ def decode_dataset(
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
timestamps_ref = []
|
||||
for cut in batch["supervisions"]["cut"]:
|
||||
for s in cut.supervisions:
|
||||
time = []
|
||||
if s.alignment is not None and "word" in s.alignment:
|
||||
time = [
|
||||
aliword.start
|
||||
for aliword in s.alignment["word"]
|
||||
if aliword.symbol != ""
|
||||
]
|
||||
timestamps_ref.append(time)
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -547,12 +590,18 @@ def decode_dataset(
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
for name, (hyps, timestamps_hyp) in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
|
||||
timestamps_ref
|
||||
)
|
||||
for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
|
||||
cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
|
||||
):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
this_batch.append(
|
||||
(cut_id, ref_words, hyp_words, time_ref, time_hyp)
|
||||
)
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -570,15 +619,19 @@ def decode_dataset(
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
results_dict: Dict[
|
||||
str,
|
||||
List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
|
||||
],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
test_set_delays = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
store_transcripts_and_timestamps(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
@ -587,10 +640,11 @@ def save_results(
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
wer, mean_delay, var_delay = write_error_stats_with_timestamps(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
test_set_delays[key] = (mean_delay, var_delay)
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
@ -604,6 +658,19 @@ def save_results(
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
|
||||
delays_info = (
|
||||
params.res_dir
|
||||
/ f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(delays_info, "w") as f:
|
||||
print("settings\tsymbol-delay", file=f)
|
||||
for key, val in test_set_delays:
|
||||
print(
|
||||
"{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
|
||||
file=f,
|
||||
)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
@ -611,6 +678,15 @@ def save_results(
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
s = "\nFor {}, symbol-delay of different settings are:\n".format(
|
||||
test_set_name
|
||||
)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_delays:
|
||||
s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
|
@ -377,6 +377,7 @@ def get_params() -> AttributeDict:
|
||||
"""
|
||||
params = AttributeDict(
|
||||
{
|
||||
"frame_shift_ms": 10.0,
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
|
@ -511,7 +511,7 @@ def decode_dataset(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[str, Tuple[str, List[str], List[str]]]]:
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
@ -585,7 +585,7 @@ def decode_dataset(
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[str, Tuple[str, List[str], List[str]]]],
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
|
@ -16,15 +16,22 @@
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from model import Transducer
|
||||
|
||||
from icefall import NgramLm, NgramLmStateCost
|
||||
from icefall.decode import Nbest, one_best_decoding
|
||||
from icefall.utils import add_eos, add_sos, get_texts
|
||||
from icefall.utils import (
|
||||
DecodingResults,
|
||||
add_eos,
|
||||
add_sos,
|
||||
get_texts,
|
||||
get_texts_with_timestamp,
|
||||
)
|
||||
|
||||
|
||||
def fast_beam_search_one_best(
|
||||
@ -36,7 +43,8 @@ def fast_beam_search_one_best(
|
||||
max_states: int,
|
||||
max_contexts: int,
|
||||
temperature: float = 1.0,
|
||||
) -> List[List[int]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
A lattice is first obtained using fast beam search, and then
|
||||
@ -60,8 +68,12 @@ def fast_beam_search_one_best(
|
||||
Max contexts pre stream per frame.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -75,8 +87,11 @@ def fast_beam_search_one_best(
|
||||
)
|
||||
|
||||
best_path = one_best_decoding(lattice)
|
||||
hyps = get_texts(best_path)
|
||||
return hyps
|
||||
|
||||
if not return_timestamps:
|
||||
return get_texts(best_path)
|
||||
else:
|
||||
return get_texts_with_timestamp(best_path)
|
||||
|
||||
|
||||
def fast_beam_search_nbest_LG(
|
||||
@ -91,7 +106,8 @@ def fast_beam_search_nbest_LG(
|
||||
nbest_scale: float = 0.5,
|
||||
use_double_scores: bool = True,
|
||||
temperature: float = 1.0,
|
||||
) -> List[List[int]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
The process to get the results is:
|
||||
@ -128,8 +144,12 @@ def fast_beam_search_nbest_LG(
|
||||
single precision.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -194,9 +214,10 @@ def fast_beam_search_nbest_LG(
|
||||
best_hyp_indexes = ragged_tot_scores.argmax()
|
||||
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
|
||||
return hyps
|
||||
if not return_timestamps:
|
||||
return get_texts(best_path)
|
||||
else:
|
||||
return get_texts_with_timestamp(best_path)
|
||||
|
||||
|
||||
def fast_beam_search_nbest(
|
||||
@ -211,7 +232,8 @@ def fast_beam_search_nbest(
|
||||
nbest_scale: float = 0.5,
|
||||
use_double_scores: bool = True,
|
||||
temperature: float = 1.0,
|
||||
) -> List[List[int]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
The process to get the results is:
|
||||
@ -248,8 +270,12 @@ def fast_beam_search_nbest(
|
||||
single precision.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -278,9 +304,10 @@ def fast_beam_search_nbest(
|
||||
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
|
||||
return hyps
|
||||
if not return_timestamps:
|
||||
return get_texts(best_path)
|
||||
else:
|
||||
return get_texts_with_timestamp(best_path)
|
||||
|
||||
|
||||
def fast_beam_search_nbest_oracle(
|
||||
@ -296,7 +323,8 @@ def fast_beam_search_nbest_oracle(
|
||||
use_double_scores: bool = True,
|
||||
nbest_scale: float = 0.5,
|
||||
temperature: float = 1.0,
|
||||
) -> List[List[int]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
A lattice is first obtained using fast beam search, and then
|
||||
@ -337,8 +365,12 @@ def fast_beam_search_nbest_oracle(
|
||||
yields more unique paths.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -377,8 +409,10 @@ def fast_beam_search_nbest_oracle(
|
||||
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
return hyps
|
||||
if not return_timestamps:
|
||||
return get_texts(best_path)
|
||||
else:
|
||||
return get_texts_with_timestamp(best_path)
|
||||
|
||||
|
||||
def fast_beam_search(
|
||||
@ -468,8 +502,11 @@ def fast_beam_search(
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
|
||||
) -> List[int]:
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
max_sym_per_frame: int,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[int], DecodingResults]:
|
||||
"""Greedy search for a single utterance.
|
||||
Args:
|
||||
model:
|
||||
@ -479,8 +516,12 @@ def greedy_search(
|
||||
max_sym_per_frame:
|
||||
Maximum number of symbols per frame. If it is set to 0, the WER
|
||||
would be 100%.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
@ -506,6 +547,10 @@ def greedy_search(
|
||||
t = 0
|
||||
hyp = [blank_id] * context_size
|
||||
|
||||
# timestamp[i] is the frame index after subsampling
|
||||
# on which hyp[i] is decoded
|
||||
timestamp = []
|
||||
|
||||
# Maximum symbols per utterance.
|
||||
max_sym_per_utt = 1000
|
||||
|
||||
@ -532,6 +577,7 @@ def greedy_search(
|
||||
y = logits.argmax().item()
|
||||
if y not in (blank_id, unk_id):
|
||||
hyp.append(y)
|
||||
timestamp.append(t)
|
||||
decoder_input = torch.tensor(
|
||||
[hyp[-context_size:]], device=device
|
||||
).reshape(1, context_size)
|
||||
@ -546,14 +592,21 @@ def greedy_search(
|
||||
t += 1
|
||||
hyp = hyp[context_size:] # remove blanks
|
||||
|
||||
return hyp
|
||||
if not return_timestamps:
|
||||
return hyp
|
||||
else:
|
||||
return DecodingResults(
|
||||
tokens=[hyp],
|
||||
timestamps=[timestamp],
|
||||
)
|
||||
|
||||
|
||||
def greedy_search_batch(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
) -> List[List[int]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
@ -563,9 +616,12 @@ def greedy_search_batch(
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||
encoder_out before padding.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return a list-of-list of token IDs containing the decoded results.
|
||||
len(ans) equals to encoder_out.size(0).
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
@ -590,6 +646,10 @@ def greedy_search_batch(
|
||||
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
# timestamp[n][i] is the frame index after subsampling
|
||||
# on which hyp[n][i] is decoded
|
||||
timestamps = [[] for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
device=device,
|
||||
@ -603,7 +663,7 @@ def greedy_search_batch(
|
||||
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
for (t, batch_size) in enumerate(batch_size_list):
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = encoder_out.data[start:end]
|
||||
@ -625,6 +685,7 @@ def greedy_search_batch(
|
||||
for i, v in enumerate(y):
|
||||
if v not in (blank_id, unk_id):
|
||||
hyps[i].append(v)
|
||||
timestamps[i].append(t)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
@ -639,11 +700,19 @@ def greedy_search_batch(
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
ans_timestamps = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
ans_timestamps.append(timestamps[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
if not return_timestamps:
|
||||
return ans
|
||||
else:
|
||||
return DecodingResults(
|
||||
tokens=ans,
|
||||
timestamps=ans_timestamps,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -656,6 +725,12 @@ class Hypothesis:
|
||||
# It contains only one entry.
|
||||
log_prob: torch.Tensor
|
||||
|
||||
# timestamp[i] is the frame index after subsampling
|
||||
# on which ys[i] is decoded
|
||||
timestamp: List[int]
|
||||
|
||||
state_cost: Optional[NgramLmStateCost] = None
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Return a string representation of self.ys"""
|
||||
@ -803,7 +878,8 @@ def modified_beam_search(
|
||||
encoder_out_lens: torch.Tensor,
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
) -> List[List[int]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||
|
||||
Args:
|
||||
@ -818,9 +894,12 @@ def modified_beam_search(
|
||||
Number of active paths during the beam search.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return a list-of-list of token IDs. ans[i] is the decoding results
|
||||
for the i-th utterance.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
@ -848,6 +927,7 @@ def modified_beam_search(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
timestamp=[],
|
||||
)
|
||||
)
|
||||
|
||||
@ -855,7 +935,7 @@ def modified_beam_search(
|
||||
|
||||
offset = 0
|
||||
finalized_B = []
|
||||
for batch_size in batch_size_list:
|
||||
for (t, batch_size) in enumerate(batch_size_list):
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = encoder_out.data[start:end]
|
||||
@ -933,30 +1013,44 @@ def modified_beam_search(
|
||||
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[k]
|
||||
new_timestamp = hyp.timestamp[:]
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
new_timestamp.append(t)
|
||||
|
||||
new_log_prob = topk_log_probs[k]
|
||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||
new_hyp = Hypothesis(
|
||||
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
|
||||
)
|
||||
B[i].add(new_hyp)
|
||||
|
||||
B = B + finalized_B
|
||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||
|
||||
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||
sorted_timestamps = [h.timestamp for h in best_hyps]
|
||||
ans = []
|
||||
ans_timestamps = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
if not return_timestamps:
|
||||
return ans
|
||||
else:
|
||||
return DecodingResults(
|
||||
tokens=ans,
|
||||
timestamps=ans_timestamps,
|
||||
)
|
||||
|
||||
|
||||
def _deprecated_modified_beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
) -> List[int]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[int], DecodingResults]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
It decodes only one utterance at a time. We keep it only for reference.
|
||||
@ -971,8 +1065,13 @@ def _deprecated_modified_beam_search(
|
||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||
beam:
|
||||
Beam size.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
|
||||
assert encoder_out.ndim == 3
|
||||
@ -992,6 +1091,7 @@ def _deprecated_modified_beam_search(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
timestamp=[],
|
||||
)
|
||||
)
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
@ -1050,17 +1150,24 @@ def _deprecated_modified_beam_search(
|
||||
for i in range(len(topk_hyp_indexes)):
|
||||
hyp = A[topk_hyp_indexes[i]]
|
||||
new_ys = hyp.ys[:]
|
||||
new_timestamp = hyp.timestamp[:]
|
||||
new_token = topk_token_indexes[i]
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
new_timestamp.append(t)
|
||||
new_log_prob = topk_log_probs[i]
|
||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||
new_hyp = Hypothesis(
|
||||
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
|
||||
)
|
||||
B.add(new_hyp)
|
||||
|
||||
best_hyp = B.get_most_probable(length_norm=True)
|
||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||
|
||||
return ys
|
||||
if not return_timestamps:
|
||||
return ys
|
||||
else:
|
||||
return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
|
||||
|
||||
|
||||
def beam_search(
|
||||
@ -1068,7 +1175,8 @@ def beam_search(
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
) -> List[int]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[int], DecodingResults]:
|
||||
"""
|
||||
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
|
||||
|
||||
@ -1083,8 +1191,13 @@ def beam_search(
|
||||
Beam size.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
@ -1111,7 +1224,7 @@ def beam_search(
|
||||
t = 0
|
||||
|
||||
B = HypothesisList()
|
||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
|
||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))
|
||||
|
||||
max_sym_per_utt = 20000
|
||||
|
||||
@ -1172,7 +1285,13 @@ def beam_search(
|
||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
||||
|
||||
# ys[:] returns a copy of ys
|
||||
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
||||
B.add(
|
||||
Hypothesis(
|
||||
ys=y_star.ys[:],
|
||||
log_prob=new_y_star_log_prob,
|
||||
timestamp=y_star.timestamp[:],
|
||||
)
|
||||
)
|
||||
|
||||
# Second, process other non-blank labels
|
||||
values, indices = log_prob.topk(beam + 1)
|
||||
@ -1181,7 +1300,14 @@ def beam_search(
|
||||
continue
|
||||
new_ys = y_star.ys + [i]
|
||||
new_log_prob = y_star.log_prob + v
|
||||
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
||||
new_timestamp = y_star.timestamp + [t]
|
||||
A.add(
|
||||
Hypothesis(
|
||||
ys=new_ys,
|
||||
log_prob=new_log_prob,
|
||||
timestamp=new_timestamp,
|
||||
)
|
||||
)
|
||||
|
||||
# Check whether B contains more than "beam" elements more probable
|
||||
# than the most probable in A
|
||||
@ -1197,7 +1323,11 @@ def beam_search(
|
||||
|
||||
best_hyp = B.get_most_probable(length_norm=True)
|
||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||
return ys
|
||||
|
||||
if not return_timestamps:
|
||||
return ys
|
||||
else:
|
||||
return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
|
||||
|
||||
|
||||
def fast_beam_search_with_nbest_rescoring(
|
||||
@ -1217,7 +1347,8 @@ def fast_beam_search_with_nbest_rescoring(
|
||||
use_double_scores: bool = True,
|
||||
nbest_scale: float = 0.5,
|
||||
temperature: float = 1.0,
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Dict[str, Union[List[List[int]], DecodingResults]]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
A lattice is first obtained using fast beam search, num_path are selected
|
||||
and rescored using a given language model. The shortest path within the
|
||||
@ -1259,10 +1390,13 @@ def fast_beam_search_with_nbest_rescoring(
|
||||
yields more unique paths.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result in a dict, where the key has the form
|
||||
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
|
||||
ngram LM scale value used during decoding, i.e., 0.1.
|
||||
'ngram_lm_scale_xx' and the value is the decoded results
|
||||
optionally with timestamps. `xx` is the ngram LM scale value
|
||||
used during decoding, i.e., 0.1.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -1340,16 +1474,18 @@ def fast_beam_search_with_nbest_rescoring(
|
||||
log_semiring=False,
|
||||
)
|
||||
|
||||
ans: Dict[str, List[List[int]]] = {}
|
||||
ans: Dict[str, Union[List[List[int]], DecodingResults]] = {}
|
||||
for s in ngram_lm_scale_list:
|
||||
key = f"ngram_lm_scale_{s}"
|
||||
tot_scores = am_scores.values + s * ngram_lm_scores
|
||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||
max_indexes = ragged_tot_scores.argmax()
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
hyps = get_texts(best_path)
|
||||
|
||||
ans[key] = hyps
|
||||
if not return_timestamps:
|
||||
ans[key] = get_texts(best_path)
|
||||
else:
|
||||
ans[key] = get_texts_with_timestamp(best_path)
|
||||
|
||||
return ans
|
||||
|
||||
@ -1373,7 +1509,8 @@ def fast_beam_search_with_nbest_rnn_rescoring(
|
||||
use_double_scores: bool = True,
|
||||
nbest_scale: float = 0.5,
|
||||
temperature: float = 1.0,
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
return_timestamps: bool = False,
|
||||
) -> Dict[str, Union[List[List[int]], DecodingResults]]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
A lattice is first obtained using fast beam search, num_path are selected
|
||||
and rescored using a given language model and a rnn-lm.
|
||||
@ -1419,10 +1556,13 @@ def fast_beam_search_with_nbest_rnn_rescoring(
|
||||
yields more unique paths.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
Return the decoded result in a dict, where the key has the form
|
||||
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
|
||||
ngram LM scale value used during decoding, i.e., 0.1.
|
||||
'ngram_lm_scale_xx' and the value is the decoded results
|
||||
optionally with timestamps. `xx` is the ngram LM scale value
|
||||
used during decoding, i.e., 0.1.
|
||||
"""
|
||||
lattice = fast_beam_search(
|
||||
model=model,
|
||||
@ -1534,8 +1674,180 @@ def fast_beam_search_with_nbest_rnn_rescoring(
|
||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||
max_indexes = ragged_tot_scores.argmax()
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
hyps = get_texts(best_path)
|
||||
|
||||
ans[key] = hyps
|
||||
if not return_timestamps:
|
||||
ans[key] = get_texts(best_path)
|
||||
else:
|
||||
ans[key] = get_texts_with_timestamp(best_path)
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
def modified_beam_search_ngram_rescoring(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ngram_lm: NgramLm,
|
||||
ngram_lm_scale: float,
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
) -> List[List[int]]:
|
||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, C).
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||
encoder_out before padding.
|
||||
beam:
|
||||
Number of active paths during the beam search.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
Returns:
|
||||
Return a list-of-list of token IDs. ans[i] is the decoding results
|
||||
for the i-th utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
device = next(model.parameters()).device
|
||||
lm_scale = ngram_lm_scale
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
B = [HypothesisList() for _ in range(N)]
|
||||
for i in range(N):
|
||||
B[i].add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
state_cost=NgramLmStateCost(ngram_lm),
|
||||
)
|
||||
)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||
|
||||
offset = 0
|
||||
finalized_B = []
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = encoder_out.data[start:end]
|
||||
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
finalized_B = B[batch_size:] + finalized_B
|
||||
B = B[:batch_size]
|
||||
|
||||
hyps_shape = get_hyps_shape(B).to(device)
|
||||
|
||||
A = [list(b) for b in B]
|
||||
B = [HypothesisList() for _ in range(batch_size)]
|
||||
|
||||
ys_log_probs = torch.cat(
|
||||
[
|
||||
hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale
|
||||
for hyps in A
|
||||
for hyp in hyps
|
||||
]
|
||||
) # (num_hyps, 1)
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (num_hyps, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
|
||||
|
||||
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||
# as index, so we use `to(torch.int64)` below.
|
||||
current_encoder_out = torch.index_select(
|
||||
current_encoder_out,
|
||||
dim=0,
|
||||
index=hyps_shape.row_ids(1).to(torch.int64),
|
||||
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
project_input=False,
|
||||
) # (num_hyps, 1, 1, vocab_size)
|
||||
|
||||
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
||||
|
||||
log_probs = (logits / temperature).log_softmax(
|
||||
dim=-1
|
||||
) # (num_hyps, vocab_size)
|
||||
|
||||
log_probs.add_(ys_log_probs)
|
||||
vocab_size = log_probs.size(-1)
|
||||
log_probs = log_probs.reshape(-1)
|
||||
|
||||
row_splits = hyps_shape.row_splits(1) * vocab_size
|
||||
log_probs_shape = k2.ragged.create_ragged_shape2(
|
||||
row_splits=row_splits, cached_tot_size=log_probs.numel()
|
||||
)
|
||||
ragged_log_probs = k2.RaggedTensor(
|
||||
shape=log_probs_shape, value=log_probs
|
||||
)
|
||||
|
||||
for i in range(batch_size):
|
||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||
|
||||
for k in range(len(topk_hyp_indexes)):
|
||||
hyp_idx = topk_hyp_indexes[k]
|
||||
hyp = A[i][hyp_idx]
|
||||
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[k]
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
state_cost = hyp.state_cost.forward_one_step(new_token)
|
||||
else:
|
||||
state_cost = hyp.state_cost
|
||||
|
||||
# We only keep AM scores in new_hyp.log_prob
|
||||
new_log_prob = (
|
||||
topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale
|
||||
)
|
||||
|
||||
new_hyp = Hypothesis(
|
||||
ys=new_ys, log_prob=new_log_prob, state_cost=state_cost
|
||||
)
|
||||
B[i].add(new_hyp)
|
||||
|
||||
B = B + finalized_B
|
||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||
|
||||
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
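In the rescoring loop above, each Hypothesis keeps its acoustic score and its n-gram LM state separately: the scaled LM score is added only when ranking candidates and is subtracted again before the new hypothesis is stored, so hyp.log_prob always stays purely acoustic. A small numeric sketch of that bookkeeping (all values hypothetical):

lm_scale = 0.3
am_log_prob = -4.0        # hyp.log_prob: acoustic score only
lm_score = -2.5           # hyp.state_cost.lm_score: n-gram LM score

# Ranking/pruning uses the combined score, as in ys_log_probs above:
ranking_score = am_log_prob + lm_scale * lm_score        # -4.75

# Suppose the joiner assigns log-prob -1.2 to the chosen token; the top-k
# value then still contains the LM contribution, which is removed again so
# that new_hyp.log_prob remains purely acoustic:
topk_log_prob = ranking_score + (-1.2)                   # -5.95
new_am_log_prob = topk_log_prob - lm_scale * lm_score    # -5.2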
|
@ -380,14 +380,13 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
if params.simulate_streaming:
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
|
@ -462,14 +462,13 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
if params.simulate_streaming:
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
|
@ -24,6 +24,11 @@ with the given torchscript model for the same input.
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from icefall import is_module_available
|
||||
|
||||
if not is_module_available("onnxruntime"):
|
||||
raise ValueError("Please 'pip install onnxruntime' first.")
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
|
@ -21,6 +21,11 @@ This file is to test that models can be exported to onnx.
|
||||
"""
|
||||
import os
|
||||
|
||||
from icefall import is_module_available
|
||||
|
||||
if not is_module_available("onnxruntime"):
|
||||
raise ValueError("Please 'pip install onnxruntime' first.")
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
from conformer import (
|
||||
|
@ -328,6 +328,7 @@ def get_parser():
|
||||
help="The probability to select a batch from the GigaSpeech dataset",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -106,6 +106,22 @@ Usage:
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
To evaluate symbol delay, you should:
|
||||
(1) Generate cuts with word-time alignments:
|
||||
./local/add_alignment_librispeech.py \
|
||||
--alignments-dir data/alignment \
|
||||
--cuts-in-dir data/fbank \
|
||||
--cuts-out-dir data/fbank_ali
|
||||
(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
|
||||
For example:
|
||||
./pruned_transducer_stateless4/decode.py \
|
||||
--epoch 40 \
|
||||
--avg 20 \
|
||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--manifest-dir data/fbank_ali
|
||||
"""
|
||||
|
||||
|
||||
@ -142,10 +158,12 @@ from icefall.checkpoint import (
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
DecodingResults,
|
||||
parse_hyp_and_timestamp,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
store_transcripts_and_timestamps,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
write_error_stats_with_timestamps,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
@ -318,7 +336,7 @@ def get_parser():
|
||||
"--left-context",
|
||||
type=int,
|
||||
default=64,
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
help="left context can be seen during decoding (in frames after subsampling)", # noqa
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -350,7 +368,7 @@ def decode_one_batch(
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
@ -358,9 +376,10 @@ def decode_one_batch(
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
- value: It is a tuple. `len(value[0])` and `len(value[1])` are both
|
||||
equal to the batch size. `value[0][i]` and `value[1][i]`
|
||||
are the decoding result and timestamps for the i-th utterance
|
||||
in the given batch respectively.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
@ -379,8 +398,8 @@ def decode_one_batch(
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
Return the decoding result and timestamps. See above description for the
|
||||
format of the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
@ -392,14 +411,13 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
if params.simulate_streaming:
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
@ -412,10 +430,8 @@ def decode_one_batch(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
res = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -423,11 +439,10 @@ def decode_one_batch(
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
res = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -437,11 +452,10 @@ def decode_one_batch(
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
res = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -451,11 +465,10 @@ def decode_one_batch(
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
res = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -466,56 +479,67 @@ def decode_one_batch(
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
hyp_tokens = greedy_search_batch(
|
||||
res = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
res = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
return_timestamps=True,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
tokens = []
|
||||
timestamps = []
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
res = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
return_timestamps=True,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
res = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
return_timestamps=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
tokens.extend(res.tokens)
|
||||
timestamps.extend(res.timestamps)
|
||||
res = DecodingResults(tokens=tokens, timestamps=timestamps)
|
||||
|
||||
hyps, timestamps = parse_hyp_and_timestamp(
|
||||
decoding_method=params.decoding_method,
|
||||
res=res,
|
||||
sp=sp,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
frame_shift_ms=params.frame_shift_ms,
|
||||
word_table=word_table,
|
||||
)
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
return {"greedy_search": (hyps, timestamps)}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
@ -526,9 +550,9 @@ def decode_one_batch(
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
return {key: (hyps, timestamps)}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -538,7 +562,9 @@ def decode_dataset(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
) -> Dict[
|
||||
str, List[Tuple[str, List[str], List[str], List[float], List[float]]]
|
||||
]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
@ -559,9 +585,12 @@ def decode_dataset(
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
Its value is a list of tuples. Each tuple contains five elements:
|
||||
- cut_id
|
||||
- reference transcript
|
||||
- predicted result
|
||||
- timestamp of reference transcript
|
||||
- timestamp of predicted result
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
@ -580,6 +609,18 @@ def decode_dataset(
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
timestamps_ref = []
|
||||
for cut in batch["supervisions"]["cut"]:
|
||||
for s in cut.supervisions:
|
||||
time = []
|
||||
if s.alignment is not None and "word" in s.alignment:
|
||||
time = [
|
||||
aliword.start
|
||||
for aliword in s.alignment["word"]
|
||||
if aliword.symbol != ""
|
||||
]
|
||||
timestamps_ref.append(time)
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
@ -589,12 +630,18 @@ def decode_dataset(
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
for name, (hyps, timestamps_hyp) in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
|
||||
timestamps_ref
|
||||
)
|
||||
for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
|
||||
cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
|
||||
):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
this_batch.append(
|
||||
(cut_id, ref_words, hyp_words, time_ref, time_hyp)
|
||||
)
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -612,15 +659,19 @@ def decode_dataset(
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[str, Tuple[List[str], List[str]]]],
|
||||
results_dict: Dict[
|
||||
str,
|
||||
List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
|
||||
],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
test_set_delays = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
store_transcripts_and_timestamps(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
@ -629,10 +680,11 @@ def save_results(
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
wer, mean_delay, var_delay = write_error_stats_with_timestamps(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
test_set_delays[key] = (mean_delay, var_delay)
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
@ -646,6 +698,19 @@ def save_results(
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
|
||||
delays_info = (
|
||||
params.res_dir
|
||||
/ f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(delays_info, "w") as f:
|
||||
print("settings\tsymbol-delay", file=f)
|
||||
for key, val in test_set_delays:
|
||||
print(
|
||||
"{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
|
||||
file=f,
|
||||
)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
@ -653,6 +718,15 @@ def save_results(
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
s = "\nFor {}, symbol-delay of different settings are:\n".format(
|
||||
test_set_name
|
||||
)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_delays:
|
||||
s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
|
@ -386,6 +386,7 @@ def get_params() -> AttributeDict:
|
||||
"""
|
||||
params = AttributeDict(
|
||||
{
|
||||
"frame_shift_ms": 10.0,
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
|
@ -378,14 +378,13 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
if params.simulate_streaming:
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
|
@ -21,7 +21,6 @@ import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from multi_quantization.prediction import JointCodebookLoss
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
@ -74,6 +73,14 @@ class Transducer(nn.Module):
|
||||
encoder_dim, vocab_size, initial_speed=0.5
|
||||
)
|
||||
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
|
||||
|
||||
from icefall import is_module_available
|
||||
|
||||
if not is_module_available("multi_quantization"):
|
||||
raise ValueError("Please 'pip install multi_quantization' first.")
|
||||
|
||||
from multi_quantization.prediction import JointCodebookLoss
|
||||
|
||||
if num_codebooks > 0:
|
||||
self.codebook_loss_net = JointCodebookLoss(
|
||||
predictor_channels=encoder_dim,
|
||||
|
@ -28,18 +28,21 @@ from typing import List, Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import multi_quantization as quantization
|
||||
|
||||
from icefall import is_module_available
|
||||
|
||||
if not is_module_available("multi_quantization"):
|
||||
raise ValueError("Please 'pip install multi_quantization' first.")
|
||||
|
||||
import multi_quantization as quantization
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from hubert_xlarge import HubertXlargeFineTuned
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
)
|
||||
from lhotse import CutSet, load_manifest
|
||||
from lhotse.cut import MonoCut
|
||||
from lhotse.features.io import NumpyHdf5Writer
|
||||
|
||||
from icefall.utils import AttributeDict, setup_logger
|
||||
|
||||
|
||||
class CodebookIndexExtractor:
|
||||
"""
|
||||
|
@ -327,7 +327,7 @@ def decode_dataset(
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[str, Tuple[List[str], List[str]]]],
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
|
@ -40,6 +40,11 @@ https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_s
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from icefall import is_module_available
|
||||
|
||||
if not is_module_available("onnxruntime"):
|
||||
raise ValueError("Please 'pip install onnxruntime' first.")
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
|
@ -49,6 +49,12 @@ from typing import List
|
||||
import k2
|
||||
import kaldifeat
|
||||
import numpy as np
|
||||
|
||||
from icefall import is_module_available
|
||||
|
||||
if not is_module_available("onnxruntime"):
|
||||
raise ValueError("Please 'pip install onnxruntime' first.")
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
|
@ -50,6 +50,7 @@ from .utils import (
|
||||
get_executor,
|
||||
get_texts,
|
||||
is_jit_tracing,
|
||||
is_module_available,
|
||||
l1_norm,
|
||||
l2_norm,
|
||||
linf_norm,
|
||||
@ -65,3 +66,5 @@ from .utils import (
|
||||
subsequent_chunk_mask,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .ngram_lm import NgramLm, NgramLmStateCost
|
||||
|
171
icefall/ngram_lm.py
Normal file
@ -0,0 +1,171 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from icefall.utils import is_module_available
|
||||
|
||||
|
||||
class NgramLm:
|
||||
def __init__(
|
||||
self,
|
||||
fst_filename: str,
|
||||
backoff_id: int,
|
||||
is_binary: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
fst_filename:
|
||||
Path to the FST.
|
||||
backoff_id:
|
||||
ID of the backoff symbol.
|
||||
is_binary:
|
||||
True if the given file is a binary FST.
|
||||
"""
|
||||
if not is_module_available("kaldifst"):
|
||||
raise ValueError("Please 'pip install kaldifst' first.")
|
||||
|
||||
import kaldifst
|
||||
|
||||
if is_binary:
|
||||
lm = kaldifst.StdVectorFst.read(fst_filename)
|
||||
else:
|
||||
with open(fst_filename, "r") as f:
|
||||
lm = kaldifst.compile(f.read(), acceptor=False)
|
||||
|
||||
if not lm.is_ilabel_sorted:
|
||||
kaldifst.arcsort(lm, sort_type="ilabel")
|
||||
|
||||
self.lm = lm
|
||||
self.backoff_id = backoff_id
|
||||
|
||||
def _process_backoff_arcs(
|
||||
self,
|
||||
state: int,
|
||||
cost: float,
|
||||
) -> List[Tuple[int, float]]:
|
||||
"""Similar to ProcessNonemitting() from Kaldi, this function
|
||||
returns the list of states reachable from the given state via
|
||||
backoff arcs.
|
||||
|
||||
Args:
|
||||
state:
|
||||
The input state.
|
||||
cost:
|
||||
The cost of reaching the given state from the start state.
|
||||
Returns:
|
||||
Return a list, where each element contains a tuple with two entries:
|
||||
- next_state
|
||||
- cost of next_state
|
||||
If there is no backoff arc leaving the input state, then return
|
||||
an empty list.
|
||||
"""
|
||||
ans = []
|
||||
|
||||
next_state, next_cost = self._get_next_state_and_cost_without_backoff(
|
||||
state=state,
|
||||
label=self.backoff_id,
|
||||
)
|
||||
if next_state is None:
|
||||
return ans
|
||||
ans.append((next_state, next_cost + cost))
|
||||
ans += self._process_backoff_arcs(next_state, next_cost + cost)
|
||||
return ans
|
||||
|
||||
def _get_next_state_and_cost_without_backoff(
|
||||
self, state: int, label: int
|
||||
) -> Tuple[int, float]:
|
||||
"""TODO: Add doc."""
|
||||
import kaldifst
|
||||
|
||||
arc_iter = kaldifst.ArcIterator(self.lm, state)
|
||||
num_arcs = self.lm.num_arcs(state)
|
||||
|
||||
# The LM is arc sorted by ilabel, so we use binary search below.
|
||||
left = 0
|
||||
right = num_arcs - 1
|
||||
while left <= right:
|
||||
mid = (left + right) // 2
|
||||
arc_iter.seek(mid)
|
||||
arc = arc_iter.value
|
||||
if arc.ilabel < label:
|
||||
left = mid + 1
|
||||
elif arc.ilabel > label:
|
||||
right = mid - 1
|
||||
else:
|
||||
return arc.nextstate, arc.weight.value
|
||||
|
||||
return None, None
|
||||
|
||||
def get_next_state_and_cost(
|
||||
self,
|
||||
state: int,
|
||||
label: int,
|
||||
) -> Tuple[List[int], List[float]]:
|
||||
states = [state]
|
||||
costs = [0]
|
||||
|
||||
extra_states_costs = self._process_backoff_arcs(
|
||||
state=state,
|
||||
cost=0,
|
||||
)
|
||||
|
||||
for s, c in extra_states_costs:
|
||||
states.append(s)
|
||||
costs.append(c)
|
||||
|
||||
next_states = []
|
||||
next_costs = []
|
||||
for s, c in zip(states, costs):
|
||||
ns, nc = self._get_next_state_and_cost_without_backoff(s, label)
|
||||
if ns:
|
||||
next_states.append(ns)
|
||||
next_costs.append(c + nc)
|
||||
|
||||
return next_states, next_costs
|
||||
|
||||
|
||||
class NgramLmStateCost:
|
||||
def __init__(self, ngram_lm: NgramLm, state_cost: Optional[dict] = None):
|
||||
assert ngram_lm.lm.start == 0, ngram_lm.lm.start
|
||||
self.ngram_lm = ngram_lm
|
||||
if state_cost is not None:
|
||||
self.state_cost = state_cost
|
||||
else:
|
||||
self.state_cost = defaultdict(lambda: float("inf"))
|
||||
|
||||
# At the very beginning, we are at the start state with cost 0
|
||||
self.state_cost[0] = 0.0
|
||||
|
||||
def forward_one_step(self, label: int) -> "NgramLmStateCost":
|
||||
state_cost = defaultdict(lambda: float("inf"))
|
||||
for s, c in self.state_cost.items():
|
||||
next_states, next_costs = self.ngram_lm.get_next_state_and_cost(
|
||||
s,
|
||||
label,
|
||||
)
|
||||
for ns, nc in zip(next_states, next_costs):
|
||||
state_cost[ns] = min(state_cost[ns], c + nc)
|
||||
|
||||
return NgramLmStateCost(ngram_lm=self.ngram_lm, state_cost=state_cost)
|
||||
|
||||
@property
|
||||
def lm_score(self) -> float:
|
||||
if len(self.state_cost) == 0:
|
||||
return float("-inf")
|
||||
|
||||
return -1 * min(self.state_cost.values())
|
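A minimal usage sketch of the two classes above (the FST path, backoff ID and token IDs are hypothetical; test/test_ngram_lm.py further below gives a complete, self-contained example):

from icefall import NgramLm, NgramLmStateCost

lm = NgramLm("data/lm/G.fst.txt", backoff_id=500, is_binary=False)
s = NgramLmStateCost(lm)
for token in [23, 7, 104]:
    s = s.forward_one_step(token)  # returns a new object; `s` itself is not mutated
print(s.lm_score)  # negative of the best accumulated cost over the active LM states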
460
icefall/utils.py
@ -24,9 +24,10 @@ import re
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, TextIO, Tuple, Union
|
||||
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
|
||||
|
||||
import k2
|
||||
import k2.version
|
||||
@ -248,6 +249,86 @@ def get_texts(
|
||||
return aux_labels.tolist()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecodingResults:
|
||||
# Decoded token IDs for each utterance in the batch
|
||||
tokens: List[List[int]]
|
||||
|
||||
# timestamps[i][k] contains the frame number on which tokens[i][k]
|
||||
# is decoded
|
||||
timestamps: List[List[int]]
|
||||
|
||||
# hyps[i] is the recognition results, i.e., word IDs
|
||||
# for the i-th utterance with fast_beam_search_nbest_LG.
|
||||
hyps: Union[List[List[int]], k2.RaggedTensor] = None
|
||||
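For instance, a DecodingResults for a single-utterance batch might look like this (hypothetical token IDs and frame indices):

res = DecodingResults(tokens=[[25, 7, 103]], timestamps=[[2, 9, 14]])
# res.tokens[0][k] was emitted on (subsampled) frame res.timestamps[0][k];
# res.hyps is only filled by fast_beam_search_nbest_LG.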
|
||||
|
||||
def get_tokens_and_timestamps(labels: List[int]) -> Tuple[List[int], List[int]]:
|
||||
tokens = []
|
||||
timestamps = []
|
||||
for i, v in enumerate(labels):
|
||||
if v != 0:
|
||||
tokens.append(v)
|
||||
timestamps.append(i)
|
||||
|
||||
return tokens, timestamps
|
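For example (0 is the blank ID, so blanks are dropped and the frame index of every non-blank label is kept):

tokens, timestamps = get_tokens_and_timestamps([0, 5, 0, 0, 7, 2])
# tokens == [5, 7, 2], timestamps == [1, 4, 5]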
||||
|
||||
|
||||
def get_texts_with_timestamp(
|
||||
best_paths: k2.Fsa, return_ragged: bool = False
|
||||
) -> DecodingResults:
|
||||
"""Extract the texts (as word IDs) and timestamps from the best-path FSAs.
|
||||
Args:
|
||||
best_paths:
|
||||
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
|
||||
containing multiple FSAs, which is expected to be the result
|
||||
of k2.shortest_path (otherwise the returned values won't
|
||||
be meaningful).
|
||||
return_ragged:
|
||||
True to return a ragged tensor with two axes [utt][word_id].
|
||||
False to return a list-of-list word IDs.
|
||||
Returns:
|
||||
Returns a DecodingResults object containing the decoded tokens, their
frame-level timestamps, and the word-ID hypotheses.
|
||||
"""
|
||||
if isinstance(best_paths.aux_labels, k2.RaggedTensor):
|
||||
# remove 0's and -1's.
|
||||
aux_labels = best_paths.aux_labels.remove_values_leq(0)
|
||||
# TODO: change arcs.shape() to arcs.shape
|
||||
aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
|
||||
|
||||
# remove the states and arcs axes.
|
||||
aux_shape = aux_shape.remove_axis(1)
|
||||
aux_shape = aux_shape.remove_axis(1)
|
||||
aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
|
||||
else:
|
||||
# remove axis corresponding to states.
|
||||
aux_shape = best_paths.arcs.shape().remove_axis(1)
|
||||
aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
|
||||
# remove 0's and -1's.
|
||||
aux_labels = aux_labels.remove_values_leq(0)
|
||||
|
||||
assert aux_labels.num_axes == 2
|
||||
|
||||
labels_shape = best_paths.arcs.shape().remove_axis(1)
|
||||
labels_list = k2.RaggedTensor(
|
||||
labels_shape, best_paths.labels.contiguous()
|
||||
).tolist()
|
||||
|
||||
tokens = []
|
||||
timestamps = []
|
||||
for labels in labels_list:
|
||||
token, time = get_tokens_and_timestamps(labels[:-1])
|
||||
tokens.append(token)
|
||||
timestamps.append(time)
|
||||
|
||||
return DecodingResults(
|
||||
tokens=tokens,
|
||||
timestamps=timestamps,
|
||||
hyps=aux_labels if return_ragged else aux_labels.tolist(),
|
||||
)
|
||||
|
||||
|
||||
def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
|
||||
"""Extract labels or aux_labels from the best-path FSAs.
|
||||
|
||||
@ -352,6 +433,33 @@ def store_transcripts(
|
||||
print(f"{cut_id}:\thyp={hyp}", file=f)
|
||||
|
||||
|
||||
def store_transcripts_and_timestamps(
|
||||
filename: Pathlike,
|
||||
texts: Iterable[Tuple[str, List[str], List[str], List[float], List[float]]],
|
||||
) -> None:
|
||||
"""Save predicted results and reference transcripts as well as their timestamps
|
||||
to a file.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
File to save the results to.
|
||||
texts:
|
||||
An iterable of tuples. Each tuple contains the cut_id, the reference
transcript, the predicted result, and the timestamps of the reference
and predicted words, in that order.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
with open(filename, "w") as f:
|
||||
for cut_id, ref, hyp, time_ref, time_hyp in texts:
|
||||
print(f"{cut_id}:\tref={ref}", file=f)
|
||||
print(f"{cut_id}:\thyp={hyp}", file=f)
|
||||
if len(time_ref) > 0:
|
||||
s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]"
|
||||
print(f"{cut_id}:\ttimestamp_ref={s}", file=f)
|
||||
s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
|
||||
print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
|
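The stored file then contains one block per cut, roughly like the following (cut id, words and times are hypothetical):

1089-134686-0000:	ref=['HELLO', 'WORLD']
1089-134686-0000:	hyp=['HELLO', 'WORLD']
1089-134686-0000:	timestamp_ref=[0.240, 0.920]
1089-134686-0000:	timestamp_hyp=[0.280, 0.960]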
||||
|
||||
|
||||
def write_error_stats(
|
||||
f: TextIO,
|
||||
test_set_name: str,
|
||||
@ -519,6 +627,211 @@ def write_error_stats(
|
||||
return float(tot_err_rate)
|
||||
|
||||
|
||||
def write_error_stats_with_timestamps(
|
||||
f: TextIO,
|
||||
test_set_name: str,
|
||||
results: List[Tuple[str, List[str], List[str], List[float], List[float]]],
|
||||
enable_log: bool = True,
|
||||
) -> Tuple[float, float, float]:
|
||||
"""Write statistics based on predicted results and reference transcripts
|
||||
as well as their timestamps.
|
||||
|
||||
It will write the following to the given file:
|
||||
|
||||
- WER
|
||||
- number of insertions, deletions, substitutions, corrects and total
|
||||
reference words. For example::
|
||||
|
||||
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
|
||||
reference words (2337 correct)
|
||||
|
||||
- The difference between the reference transcript and predicted result.
|
||||
An instance is given below::
|
||||
|
||||
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
|
||||
|
||||
The above example shows that the reference word is `EDISON`,
|
||||
but it is predicted to `ADDISON` (a substitution error).
|
||||
|
||||
Another example is::
|
||||
|
||||
FOR THE FIRST DAY (SIR->*) I THINK
|
||||
|
||||
The reference word `SIR` is missing in the predicted
|
||||
results (a deletion error).
|
||||
results:
|
||||
An iterable of tuples. Each tuple contains the cut_id, the reference
transcript, the predicted result, and the timestamps of the reference
and predicted words.
|
||||
enable_log:
|
||||
If True, also print detailed WER to the console.
|
||||
Otherwise, it is written only to the given file.
|
||||
|
||||
Returns:
|
||||
Return the total word error rate, the mean symbol delay, and the variance of the symbol delay.
|
||||
"""
|
||||
subs: Dict[Tuple[str, str], int] = defaultdict(int)
|
||||
ins: Dict[str, int] = defaultdict(int)
|
||||
dels: Dict[str, int] = defaultdict(int)
|
||||
|
||||
# `words` stores counts per word, as follows:
|
||||
# corr, ref_sub, hyp_sub, ins, dels
|
||||
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
||||
num_corr = 0
|
||||
ERR = "*"
|
||||
# Compute mean alignment delay on the correct words
|
||||
all_delay = []
|
||||
for cut_id, ref, hyp, time_ref, time_hyp in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR)
|
||||
has_time_ref = len(time_ref) > 0
|
||||
if has_time_ref:
|
||||
# pointer to timestamp_hyp
|
||||
p_hyp = 0
|
||||
# pointer to timestamp_ref
|
||||
p_ref = 0
|
||||
for ref_word, hyp_word in ali:
|
||||
if ref_word == ERR:
|
||||
ins[hyp_word] += 1
|
||||
words[hyp_word][3] += 1
|
||||
if has_time_ref:
|
||||
p_hyp += 1
|
||||
elif hyp_word == ERR:
|
||||
dels[ref_word] += 1
|
||||
words[ref_word][4] += 1
|
||||
if has_time_ref:
|
||||
p_ref += 1
|
||||
elif hyp_word != ref_word:
|
||||
subs[(ref_word, hyp_word)] += 1
|
||||
words[ref_word][1] += 1
|
||||
words[hyp_word][2] += 1
|
||||
if has_time_ref:
|
||||
p_hyp += 1
|
||||
p_ref += 1
|
||||
else:
|
||||
words[ref_word][0] += 1
|
||||
num_corr += 1
|
||||
if has_time_ref:
|
||||
all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
|
||||
p_hyp += 1
|
||||
p_ref += 1
|
||||
if has_time_ref:
|
||||
assert p_hyp == len(hyp), (p_hyp, len(hyp))
|
||||
assert p_ref == len(ref), (p_ref, len(ref))
|
||||
|
||||
ref_len = sum([len(r) for _, r, _, _, _ in results])
|
||||
sub_errs = sum(subs.values())
|
||||
ins_errs = sum(ins.values())
|
||||
del_errs = sum(dels.values())
|
||||
tot_errs = sub_errs + ins_errs + del_errs
|
||||
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
|
||||
|
||||
mean_delay = "inf"
|
||||
var_delay = "inf"
|
||||
num_delay = len(all_delay)
|
||||
if num_delay > 0:
|
||||
mean_delay = sum(all_delay) / num_delay
|
||||
var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
|
||||
mean_delay = "%.3f" % mean_delay
|
||||
var_delay = "%.3f" % var_delay
|
||||
|
||||
if enable_log:
|
||||
logging.info(
|
||||
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
|
||||
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
|
||||
f"{del_errs} del, {sub_errs} sub ]"
|
||||
)
|
||||
logging.info(
|
||||
f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} " # noqa
|
||||
f"computed on {num_delay} correct words"
|
||||
)
|
||||
|
||||
print(f"%WER = {tot_err_rate}", file=f)
|
||||
print(
|
||||
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
|
||||
f"{sub_errs} substitutions, over {ref_len} reference "
|
||||
f"words ({num_corr} correct)",
|
||||
file=f,
|
||||
)
|
||||
print(
|
||||
"Search below for sections starting with PER-UTT DETAILS:, "
|
||||
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
|
||||
for cut_id, ref, hyp, _, _ in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR)
|
||||
combine_successive_errors = True
|
||||
if combine_successive_errors:
|
||||
ali = [[[x], [y]] for x, y in ali]
|
||||
for i in range(len(ali) - 1):
|
||||
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
|
||||
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
|
||||
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
|
||||
ali[i] = [[], []]
|
||||
ali = [
|
||||
[
|
||||
list(filter(lambda a: a != ERR, x)),
|
||||
list(filter(lambda a: a != ERR, y)),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
ali = list(filter(lambda x: x != [[], []], ali))
|
||||
ali = [
|
||||
[
|
||||
ERR if x == [] else " ".join(x),
|
||||
ERR if y == [] else " ".join(y),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
|
||||
print(
|
||||
f"{cut_id}:\t"
|
||||
+ " ".join(
|
||||
(
|
||||
ref_word
|
||||
if ref_word == hyp_word
|
||||
else f"({ref_word}->{hyp_word})"
|
||||
for ref_word, hyp_word in ali
|
||||
)
|
||||
),
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
||||
|
||||
for count, (ref, hyp) in sorted(
|
||||
[(v, k) for k, v in subs.items()], reverse=True
|
||||
):
|
||||
print(f"{count} {ref} -> {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("DELETIONS: count ref", file=f)
|
||||
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
|
||||
print(f"{count} {ref}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("INSERTIONS: count hyp", file=f)
|
||||
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
|
||||
print(f"{count} {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print(
|
||||
"PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
|
||||
)
|
||||
for _, word, counts in sorted(
|
||||
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
|
||||
):
|
||||
(corr, ref_sub, hyp_sub, ins, dels) = counts
|
||||
tot_errs = ref_sub + hyp_sub + ins + dels
|
||||
ref_count = corr + ref_sub + dels
|
||||
hyp_count = corr + hyp_sub + ins
|
||||
|
||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||
return float(tot_err_rate), float(mean_delay), float(var_delay)
|
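The symbol delay above is computed only over correctly recognized words, as the difference between the hypothesis and reference start times. A tiny worked example with hypothetical timestamps:

time_ref = [0.24, 0.92]                 # reference word start times (s)
time_hyp = [0.28, 1.00]                 # decoded word start times (s)
delays = [h - r for h, r in zip(time_hyp, time_ref)]   # [0.04, 0.08]
mean_delay = sum(delays) / len(delays)                  # 0.06
var_delay = sum((d - mean_delay) ** 2 for d in delays) / len(delays)  # 0.0004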
||||
|
||||
|
||||
class MetricsTracker(collections.defaultdict):
|
||||
def __init__(self):
|
||||
# Passing the type 'int' to the base-class constructor
|
||||
@ -976,3 +1289,148 @@ def display_and_save_batch(
|
||||
y = sp.encode(supervisions["text"], out_type=int)
|
||||
num_tokens = sum(len(i) for i in y)
|
||||
logging.info(f"num tokens: {num_tokens}")
|
||||
|
||||
|
||||
def convert_timestamp(
|
||||
frames: List[int],
|
||||
subsampling_factor: int,
|
||||
frame_shift_ms: float = 10,
|
||||
) -> List[float]:
|
||||
"""Convert frame numbers to time (in seconds) given subsampling factor
|
||||
and frame shift (in milliseconds).
|
||||
|
||||
Args:
|
||||
frames:
|
||||
A list of frame numbers after subsampling.
|
||||
subsampling_factor:
|
||||
The subsampling factor of the model.
|
||||
frame_shift_ms:
|
||||
Frame shift in milliseconds between two contiguous frames.
|
||||
Return:
|
||||
Return the time in seconds corresponding to each given frame.
|
||||
"""
|
||||
frame_shift = frame_shift_ms / 1000.0
|
||||
time = []
|
||||
for f in frames:
|
||||
time.append(f * subsampling_factor * frame_shift)
|
||||
|
||||
return time
|
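For example, with a subsampling factor of 4 and a 10 ms frame shift, frame 25 after subsampling corresponds to 25 * 4 * 0.01 = 1.0 s:

convert_timestamp([0, 2, 25], subsampling_factor=4, frame_shift_ms=10)
# -> [0.0, 0.08, 1.0]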
||||
|
||||
|
||||
def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
|
||||
"""
|
||||
Parse timestamp of each word.
|
||||
|
||||
Args:
|
||||
tokens:
|
||||
List of tokens.
|
||||
timestamp:
|
||||
List of timestamp of each token.
|
||||
|
||||
Returns:
|
||||
List of timestamp of each word.
|
||||
"""
|
||||
start_token = b"\xe2\x96\x81".decode()  # '▁', the sentencepiece word-boundary marker
|
||||
assert len(tokens) == len(timestamp)
|
||||
ans = []
|
||||
for i in range(len(tokens)):
|
||||
flag = False
|
||||
if i == 0 or tokens[i].startswith(start_token):
|
||||
flag = True
|
||||
if len(tokens[i]) == 1 and tokens[i].startswith(start_token):
|
||||
# tokens[i] == start_token
|
||||
if i == len(tokens) - 1:
|
||||
# it is the last token
|
||||
flag = False
|
||||
elif tokens[i + 1].startswith(start_token):
|
||||
# the next token also starts with start_token
|
||||
flag = False
|
||||
if flag:
|
||||
ans.append(timestamp[i])
|
||||
return ans
|
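For example, only tokens that start a new word (the first token, or tokens beginning with '▁') contribute a timestamp:

parse_timestamp(["▁HE", "LLO", "▁WORLD"], [0.04, 0.08, 0.56])
# -> [0.04, 0.56]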
||||
|
||||
|
||||
def parse_hyp_and_timestamp(
|
||||
res: DecodingResults,
|
||||
decoding_method: str,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
subsampling_factor: int,
|
||||
frame_shift_ms: float = 10,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
) -> Tuple[List[List[str]], List[List[float]]]:
|
||||
"""Parse hypothesis and timestamp.
|
||||
|
||||
Args:
|
||||
res:
|
||||
A DecodingResults object.
|
||||
decoding_method:
|
||||
Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
sp:
|
||||
The BPE model.
|
||||
subsampling_factor:
|
||||
The integer subsampling factor.
|
||||
frame_shift_ms:
|
||||
The float frame shift used for feature extraction.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
|
||||
Returns:
|
||||
Return a list of hypothesis and timestamp.
|
||||
"""
|
||||
assert decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
|
||||
hyps = []
|
||||
timestamps = []
|
||||
|
||||
N = len(res.tokens)
|
||||
assert len(res.timestamps) == N
|
||||
use_word_table = False
|
||||
if decoding_method == "fast_beam_search_nbest_LG":
|
||||
assert word_table is not None
|
||||
use_word_table = True
|
||||
|
||||
for i in range(N):
|
||||
tokens = sp.id_to_piece(res.tokens[i])
|
||||
if use_word_table:
|
||||
words = [word_table[i] for i in res.hyps[i]]
|
||||
else:
|
||||
words = sp.decode_pieces(tokens).split()
|
||||
time = convert_timestamp(
|
||||
res.timestamps[i], subsampling_factor, frame_shift_ms
|
||||
)
|
||||
time = parse_timestamp(tokens, time)
|
||||
assert len(time) == len(words), (tokens, words)
|
||||
|
||||
hyps.append(words)
|
||||
timestamps.append(time)
|
||||
|
||||
return hyps, timestamps
|
||||
|
||||
|
||||
# `is_module_available` is copied from
|
||||
# https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9
|
||||
def is_module_available(*modules: str) -> bool:
|
||||
r"""Returns if a top-level module with :attr:`name` exists *without**
|
||||
importing it. This is generally safer than try-catch block around a
|
||||
`import X`.
|
||||
|
||||
Note: "borrowed" from torchaudio:
|
||||
"""
|
||||
import importlib
|
||||
|
||||
return all(importlib.util.find_spec(m) is not None for m in modules)
|
||||
|
@ -23,4 +23,4 @@ multi_quantization
|
||||
|
||||
onnx
|
||||
onnxruntime
|
||||
onnx_graphsurgeon -i https://pypi.ngc.nvidia.com
|
||||
kaldifst
|
||||
|
@ -3,8 +3,4 @@ kaldialign
|
||||
sentencepiece>=0.1.96
|
||||
tensorboard
|
||||
typeguard
|
||||
multi_quantization
|
||||
onnx
|
||||
onnxruntime
|
||||
--extra-index-url https://pypi.ngc.nvidia.com
|
||||
dill
|
||||
|
74
test/test_ngram_lm.py
Executable file
@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import graphviz
|
||||
|
||||
from icefall import is_module_available
|
||||
|
||||
if not is_module_available("kaldifst"):
|
||||
raise ValueError("Please 'pip install kaldifst' first.")
|
||||
|
||||
import kaldifst
|
||||
|
||||
from icefall import NgramLm, NgramLmStateCost
|
||||
|
||||
|
||||
def generate_fst(filename: str):
|
||||
s = """
|
||||
3 5 1 1 3.00464
|
||||
3 0 3 0 5.75646
|
||||
0 1 1 1 12.0533
|
||||
0 2 2 2 7.95954
|
||||
0 9.97787
|
||||
1 4 2 2 3.35436
|
||||
1 0 3 0 7.59853
|
||||
2 0 3 0
|
||||
4 2 3 0 7.43735
|
||||
4 0.551239
|
||||
5 4 2 2 0.804938
|
||||
5 1 3 0 9.67086
|
||||
"""
|
||||
fst = kaldifst.compile(s=s, acceptor=False)
|
||||
fst.write(filename)
|
||||
fst_dot = kaldifst.draw(fst, acceptor=False, portrait=True)
|
||||
source = graphviz.Source(fst_dot)
|
||||
source.render(outfile=f"{filename}.svg")
|
||||
|
||||
|
||||
def main():
|
||||
filename = "test.fst"
|
||||
generate_fst(filename)
|
||||
ngram_lm = NgramLm(filename, backoff_id=3, is_binary=True)
|
||||
for label in [1, 2, 3, 4, 5]:
|
||||
print("---label---", label)
|
||||
p = ngram_lm.get_next_state_and_cost(state=5, label=label)
|
||||
print(p)
|
||||
print("---")
|
||||
|
||||
state_cost = NgramLmStateCost(ngram_lm)
|
||||
s0 = state_cost.forward_one_step(1)
|
||||
print(s0.state_cost)
|
||||
|
||||
s1 = s0.forward_one_step(2)
|
||||
print(s1.state_cost)
|
||||
|
||||
s2 = s1.forward_one_step(2)
|
||||
print(s2.state_cost)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|