Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-14 04:22:21 +00:00)

Merge branch 'k2-fsa:master' into master
Commit: 3b142dbef3
@@ -13,9 +13,12 @@ cd egs/librispeech/ASR
 repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-04-29
 
 log "Downloading pre-trained model from $repo_url"
-git lfs install
-git clone $repo_url
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
 repo=$(basename $repo_url)
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained-epoch-25-avg-6.pt"
+popd
 
 log "Display test files"
 tree $repo/
@@ -52,6 +52,9 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
 
   if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then
     git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm
+    pushd $dl_dir/lm
+    git lfs pull --include "3-gram.unpruned.arpa"
+    popd
   fi
 fi
 
egs/csj/ASR/.gitignore (vendored, 3 lines changed)

@@ -1,7 +1,8 @@
-librispeech_*.*
+librispeech_*
 todelete*
 lang*
 notify_tg.py
 finetune_*
 misc.ini
 .vscode/*
+offline/*
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/compute_fbank_musan.py
egs/csj/ASR/local/compute_fbank_musan.py (new file, 127 lines)

@@ -0,0 +1,127 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor

ARGPARSE_DESCRIPTION = """
This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/fbank.
"""

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
    # src_dir = Path("data/manifests")
    # output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
    num_mel_bins = 80

    dataset_parts = (
        "music",
        "speech",
        "noise",
    )
    prefix = "musan"
    suffix = "jsonl.gz"
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=manifest_dir,
        prefix=prefix,
        suffix=suffix,
    )
    assert manifests is not None

    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )

    musan_cuts_path = fbank_dir / "musan_cuts.jsonl.gz"

    if musan_cuts_path.is_file():
        logging.info(f"{musan_cuts_path} already exists - skipping")
        return

    logging.info("Extracting features for Musan")

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        # create chunks of Musan with duration 5 - 10 seconds
        musan_cuts = (
            CutSet.from_manifests(
                recordings=combine(
                    part["recordings"] for part in manifests.values()
                )
            )
            .cut_into_windows(10.0)
            .filter(lambda c: c.duration > 5)
            .compute_and_store_features(
                extractor=extractor,
                storage_path=f"{fbank_dir}/musan_feats",
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                storage_type=LilcomChunkyWriter,
            )
        )
        musan_cuts.to_file(musan_cuts_path)


def get_args():
    parser = argparse.ArgumentParser(
        description=ARGPARSE_DESCRIPTION,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--manifest-dir", type=Path, help="Path to save manifests"
    )
    parser.add_argument(
        "--fbank-dir", type=Path, help="Path to save fbank features"
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    compute_fbank_musan(args.manifest_dir, args.fbank_dir)
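The cuts written to data/fbank/musan_cuts.jsonl.gz are normally used as a noise source for on-the-fly augmentation during training. A minimal usage sketch, assuming lhotse's CutMix transform and the output path above:

from lhotse import load_manifest
from lhotse.dataset import CutMix

# Load the MUSAN cuts produced by compute_fbank_musan.py (path assumed).
musan_cuts = load_manifest("data/fbank/musan_cuts.jsonl.gz")

# Mix MUSAN noise into training cuts with 50% probability and a random
# SNR between 10 and 20 dB.
noise_transform = CutMix(cuts=musan_cuts, prob=0.5, snr=(10, 20))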
@@ -64,9 +64,9 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   # Example: lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -c local/conf/disfluent.ini
   # NOTE: In case multiple config files are supplied, the second config file and onwards will inherit
   # the segment boundaries of the first config file.
-  if [ ! -e $csj_manifest_dir/.librispeech.done ]; then
+  if [ ! -e $csj_manifest_dir/.csj.done ]; then
     lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -j 4
-    touch $csj_manifest_dir/.librispeech.done
+    touch $csj_manifest_dir/.csj.done
   fi
 fi
 
@@ -1304,7 +1304,7 @@ results at:
 
 ##### Baseline-2
 
-It has 88.98 M parameters. Compared to the model in pruned_transducer_stateless2, its has more
+It has 87.8 M parameters. Compared to the model in pruned_transducer_stateless2, its has more
 layers (24 v.s 12) but a narrower model (1536 feedforward dim and 384 encoder dim vs 2048 feed forward dim and 512 encoder dim).
 
 | | test-clean | test-other | comment |
egs/librispeech/ASR/add_alignments.sh (new executable file, 12 lines)

@@ -0,0 +1,12 @@
#!/usr/bin/env bash

set -eou pipefail

alignments_dir=data/alignment
cuts_in_dir=data/fbank
cuts_out_dir=data/fbank_ali

python3 ./local/add_alignment_librispeech.py \
  --alignments-dir $alignments_dir \
  --cuts-in-dir $cuts_in_dir \
  --cuts-out-dir $cuts_out_dir
egs/librispeech/ASR/generate-lm.sh (new executable file, 20 lines)

@@ -0,0 +1,20 @@
#!/usr/bin/env bash

lang_dir=data/lang_bpe_500

for ngram in 2 3 5; do
  if [ ! -f $lang_dir/${ngram}gram.arpa ]; then
    ./shared/make_kn_lm.py \
      -ngram-order ${ngram} \
      -text $lang_dir/transcript_tokens.txt \
      -lm $lang_dir/${ngram}gram.arpa
  fi

  if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then
    python3 -m kaldilm \
      --read-symbol-table="$lang_dir/tokens.txt" \
      --disambig-symbol='#0' \
      --max-order=${ngram} \
      $lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt
  fi
done
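The token-level LMs produced here are what the n-gram rescoring decoding added later in this commit consumes. A minimal sketch of loading one of them with icefall's NgramLm, mirroring the decode.py change below (the backoff id 500 matches the new --backoff-id default; paths are the ones used above):

from pathlib import Path

from icefall import NgramLm

lang_dir = Path("data/lang_bpe_500")

# 3gram.fst.txt is produced by generate-lm.sh above.
ngram_lm = NgramLm(
    str(lang_dir / "3gram.fst.txt"),
    backoff_id=500,  # assumed to be the id of the backoff symbol in tokens.txt
    is_binary=False,
)
print("num states:", ngram_lm.lm.num_states)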
egs/librispeech/ASR/local/add_alignment_librispeech.py (new executable file, 196 lines)

@@ -0,0 +1,196 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file adds alignments from https://github.com/CorentinJ/librispeech-alignments  # noqa
to the existing fbank features dir (e.g., data/fbank)
and save cuts to a new dir (e.g., data/fbank_ali).
"""

import argparse
import logging
import zipfile
from pathlib import Path
from typing import List

from lhotse import CutSet, load_manifest_lazy
from lhotse.recipes.librispeech import parse_alignments
from lhotse.utils import is_module_available

LIBRISPEECH_ALIGNMENTS_URL = (
    "https://drive.google.com/uc?id=1WYfgr31T-PPwMcxuAq09XZfHQO5Mw8fE"
)

DATASET_PARTS = [
    "dev-clean",
    "dev-other",
    "test-clean",
    "test-other",
    "train-clean-100",
    "train-clean-360",
    "train-other-500",
]


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--alignments-dir",
        type=str,
        default="data/alignment",
        help="The dir to save alignments.",
    )

    parser.add_argument(
        "--cuts-in-dir",
        type=str,
        default="data/fbank",
        help="The dir of the existing cuts without alignments.",
    )

    parser.add_argument(
        "--cuts-out-dir",
        type=str,
        default="data/fbank_ali",
        help="The dir to save the new cuts with alignments",
    )

    return parser


def download_alignments(
    target_dir: str, alignments_url: str = LIBRISPEECH_ALIGNMENTS_URL
):
    """
    Download and extract the alignments.

    Note: If you can not access drive.google.com, you could download the file
    `LibriSpeech-Alignments.zip` from huggingface:
    https://huggingface.co/Zengwei/librispeech-alignments
    and extract the zip file manually.

    Args:
      target_dir:
        The dir to save alignments.
      alignments_url:
        The URL of alignments.
    """
    """Modified from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/librispeech.py"""  # noqa
    target_dir = Path(target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    completed_detector = target_dir / ".ali_completed"
    if completed_detector.is_file():
        logging.info("The alignment files already exist.")
        return

    ali_zip_path = target_dir / "LibriSpeech-Alignments.zip"
    if not ali_zip_path.is_file():
        assert is_module_available(
            "gdown"
        ), 'To download LibriSpeech alignments, please install "pip install gdown"'  # noqa
        import gdown

        gdown.download(alignments_url, output=str(ali_zip_path))

    with zipfile.ZipFile(str(ali_zip_path)) as f:
        f.extractall(path=target_dir)
    completed_detector.touch()


def add_alignment(
    alignments_dir: str,
    cuts_in_dir: str = "data/fbank",
    cuts_out_dir: str = "data/fbank_ali",
    dataset_parts: List[str] = DATASET_PARTS,
):
    """
    Add alignment info to existing cuts.

    Args:
      alignments_dir:
        The dir of the alignments.
      cuts_in_dir:
        The dir of the existing cuts.
      cuts_out_dir:
        The dir to save the new cuts with alignments.
      dataset_parts:
        Librispeech parts to add alignments.
    """
    alignments_dir = Path(alignments_dir)
    cuts_in_dir = Path(cuts_in_dir)
    cuts_out_dir = Path(cuts_out_dir)
    cuts_out_dir.mkdir(parents=True, exist_ok=True)

    for part in dataset_parts:
        logging.info(f"Processing {part}")

        cuts_in_path = cuts_in_dir / f"librispeech_cuts_{part}.jsonl.gz"
        if not cuts_in_path.is_file():
            logging.info(f"{cuts_in_path} does not exist - skipping.")
            continue
        cuts_out_path = cuts_out_dir / f"librispeech_cuts_{part}.jsonl.gz"
        if cuts_out_path.is_file():
            logging.info(f"{part} already exists - skipping.")
            continue

        # parse alignments
        alignments = {}
        part_ali_dir = alignments_dir / "LibriSpeech" / part
        for ali_path in part_ali_dir.rglob("*.alignment.txt"):
            ali = parse_alignments(ali_path)
            alignments.update(ali)
        logging.info(
            f"{part} has {len(alignments.keys())} cuts with alignments."
        )

        # add alignment attribute and write out
        cuts_in = load_manifest_lazy(cuts_in_path)
        with CutSet.open_writer(cuts_out_path) as writer:
            for cut in cuts_in:
                for idx, subcut in enumerate(cut.supervisions):
                    origin_id = subcut.id.split("_")[0]
                    if origin_id in alignments:
                        ali = alignments[origin_id]
                    else:
                        logging.info(
                            f"Warning: {origin_id} does not has alignment."
                        )
                        ali = []
                    subcut.alignment = {"word": ali}
                writer.write(cut, flush=True)


def main():
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)

    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    download_alignments(args.alignments_dir)
    add_alignment(args.alignments_dir, args.cuts_in_dir, args.cuts_out_dir)


if __name__ == "__main__":
    main()
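Once the script has run, every supervision in the new cuts carries a word-level alignment. A small sketch of reading it back, assuming lhotse's AlignmentItem fields (symbol, start, duration) and the default output dir:

from lhotse import load_manifest_lazy

cuts = load_manifest_lazy(
    "data/fbank_ali/librispeech_cuts_test-clean.jsonl.gz"
)

for cut in cuts:
    for sup in cut.supervisions:
        for item in sup.alignment["word"]:
            # Skip silence entries, which have an empty symbol.
            if item.symbol != "":
                print(item.symbol, item.start, item.duration)
    break  # only inspect the first cut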
@@ -115,10 +115,12 @@ from beam_search import (
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
+    modified_beam_search_ngram_rescoring,
 )
 from librispeech import LibriSpeech
 from train import add_model_arguments, get_params, get_transducer_model
 
+from icefall import NgramLm
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,

@@ -214,6 +216,7 @@ def get_parser():
           - fast_beam_search_nbest
           - fast_beam_search_nbest_oracle
           - fast_beam_search_nbest_LG
+          - modified_beam_search_ngram_rescoring
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
         """,

@@ -303,6 +306,22 @@ def get_parser():
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )
 
+    parser.add_argument(
+        "--tokens-ngram",
+        type=int,
+        default=3,
+        help="""Token Ngram used for rescoring.
+        Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+    )
+
+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="""ID of the backoff symbol.
+        Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+    )
+
     add_model_arguments(parser)
 
     return parser

@@ -315,6 +334,8 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    ngram_lm: Optional[NgramLm] = None,
+    ngram_lm_scale: float = 1.0,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:

@@ -448,6 +469,17 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_ngram_rescoring":
+        hyp_tokens = modified_beam_search_ngram_rescoring(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
 

@@ -497,6 +529,8 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    ngram_lm: Optional[NgramLm] = None,
+    ngram_lm_scale: float = 1.0,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 

@@ -546,6 +580,8 @@ def decode_dataset(
             decoding_graph=decoding_graph,
             word_table=word_table,
             batch=batch,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
         )
 
         for name, hyps in hyps_dict.items():

@@ -631,6 +667,7 @@ def main():
         "fast_beam_search_nbest_LG",
         "fast_beam_search_nbest_oracle",
         "modified_beam_search",
+        "modified_beam_search_ngram_rescoring",
     )
     params.res_dir = params.exp_dir / params.decoding_method
 

@@ -655,6 +692,7 @@ def main():
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+    params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
 
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"

@@ -768,6 +806,15 @@ def main():
     model.to(device)
     model.eval()
 
+    lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+    logging.info(f"lm filename: {lm_filename}")
+    ngram_lm = NgramLm(
+        str(params.lang_dir / lm_filename),
+        backoff_id=params.backoff_id,
+        is_binary=False,
+    )
+    logging.info(f"num states: {ngram_lm.lm.num_states}")
+
     if "fast_beam_search" in params.decoding_method:
         if params.decoding_method == "fast_beam_search_nbest_LG":
             lexicon = Lexicon(params.lang_dir)

@@ -812,6 +859,8 @@ def main():
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=params.ngram_lm_scale,
         )
 
         save_results(
@@ -42,6 +42,11 @@ import argparse
 import logging
 from typing import List, Optional, Tuple
 
+from icefall import is_module_available
+
+if not is_module_available("onnxruntime"):
+    raise ValueError("Please 'pip install onnxruntime' first.")
+
 import onnxruntime as ort
 import sentencepiece as spm
 import torch
@@ -91,6 +91,22 @@ Usage:
     --beam 20.0 \
     --max-contexts 8 \
     --max-states 64
+
+To evaluate symbol delay, you should:
+(1) Generate cuts with word-time alignments:
+./local/add_alignment_librispeech.py \
+    --alignments-dir data/alignment \
+    --cuts-in-dir data/fbank \
+    --cuts-out-dir data/fbank_ali
+(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
+For example:
+./lstm_transducer_stateless3/decode.py \
+    --epoch 40 \
+    --avg 20 \
+    --exp-dir ./lstm_transducer_stateless3/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search \
+    --manifest-dir data/fbank_ali
 """
 
 

@@ -127,10 +143,12 @@ from icefall.checkpoint import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    DecodingResults,
+    parse_hyp_and_timestamp,
     setup_logger,
-    store_transcripts,
+    store_transcripts_and_timestamps,
     str2bool,
-    write_error_stats,
+    write_error_stats_with_timestamps,
 )
 
 LOG_EPS = math.log(1e-10)

@@ -314,7 +332,7 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[str]]]:
+) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:

@@ -322,9 +340,11 @@ def decode_one_batch(
         if greedy_search is used, it would be "greedy_search"
         If beam search with a beam size of 7 is used, it would be
         "beam_7"
-      - value: It contains the decoding result. `len(value)` equals to
-        batch size. `value[i]` is the decoding result for the i-th
-        utterance in the given batch.
+      - value: It is a tuple. `len(value[0])` and `len(value[1])` are both
+        equal to the batch size. `value[0][i]` and `value[1][i]`
+        are the decoding result and timestamps for the i-th utterance
+        in the given batch respectively.
+
     Args:
       params:
         It's the return value of :func:`get_params`.

@@ -343,8 +363,8 @@ def decode_one_batch(
         only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
-      Return the decoding result. See above description for the format of
-      the returned dict.
+      Return the decoding result and timestamps. See above description for the
+      format of the returned dict.
     """
     device = next(model.parameters()).device
     feature = batch["inputs"]

@@ -370,10 +390,8 @@ def decode_one_batch(
         x=feature, x_lens=feature_lens
     )
 
-    hyps = []
-
     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search_one_best(
+        res = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,

@@ -381,11 +399,10 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_nbest_LG":
-        hyp_tokens = fast_beam_search_nbest_LG(
+        res = fast_beam_search_nbest_LG(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,

@@ -395,11 +412,10 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            return_timestamps=True,
         )
-        for hyp in hyp_tokens:
-            hyps.append([word_table[i] for i in hyp])
     elif params.decoding_method == "fast_beam_search_nbest":
-        hyp_tokens = fast_beam_search_nbest(
+        res = fast_beam_search_nbest(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,

@@ -409,11 +425,10 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_nbest_oracle":
-        hyp_tokens = fast_beam_search_nbest_oracle(
+        res = fast_beam_search_nbest_oracle(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,

@@ -424,56 +439,67 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=sp.encode(supervisions["text"]),
             nbest_scale=params.nbest_scale,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
     ):
-        hyp_tokens = greedy_search_batch(
+        res = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search":
-        hyp_tokens = modified_beam_search(
+        res = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
+        tokens = []
+        timestamps = []
         for i in range(batch_size):
             # fmt: off
             encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
             # fmt: on
             if params.decoding_method == "greedy_search":
-                hyp = greedy_search(
+                res = greedy_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     max_sym_per_frame=params.max_sym_per_frame,
+                    return_timestamps=True,
                 )
             elif params.decoding_method == "beam_search":
-                hyp = beam_search(
+                res = beam_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
+                    return_timestamps=True,
                 )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            tokens.extend(res.tokens)
+            timestamps.extend(res.timestamps)
+        res = DecodingResults(tokens=tokens, timestamps=timestamps)
+
+    hyps, timestamps = parse_hyp_and_timestamp(
+        decoding_method=params.decoding_method,
+        res=res,
+        sp=sp,
+        subsampling_factor=params.subsampling_factor,
+        frame_shift_ms=params.frame_shift_ms,
+        word_table=word_table,
+    )
 
     if params.decoding_method == "greedy_search":
-        return {"greedy_search": hyps}
+        return {"greedy_search": (hyps, timestamps)}
     elif "fast_beam_search" in params.decoding_method:
         key = f"beam_{params.beam}_"
         key += f"max_contexts_{params.max_contexts}_"

@@ -484,9 +510,9 @@ def decode_one_batch(
         if "LG" in params.decoding_method:
             key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
 
-        return {key: hyps}
+        return {key: (hyps, timestamps)}
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
 
 
 def decode_dataset(

@@ -496,7 +522,9 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+) -> Dict[
+    str, List[Tuple[str, List[str], List[str], List[float], List[float]]]
+]:
     """Decode dataset.
 
     Args:

@@ -517,9 +545,12 @@ def decode_dataset(
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
-      Its value is a list of tuples. Each tuple contains two elements:
-      The first is the reference transcript, and the second is the
-      predicted result.
+      Its value is a list of tuples. Each tuple contains five elements:
+        - cut_id
+        - reference transcript
+        - predicted result
+        - timestamp of reference transcript
+        - timestamp of predicted result
     """
     num_cuts = 0
 

@@ -538,6 +569,18 @@ def decode_dataset(
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
+        timestamps_ref = []
+        for cut in batch["supervisions"]["cut"]:
+            for s in cut.supervisions:
+                time = []
+                if s.alignment is not None and "word" in s.alignment:
+                    time = [
+                        aliword.start
+                        for aliword in s.alignment["word"]
+                        if aliword.symbol != ""
+                    ]
+                timestamps_ref.append(time)
+
         hyps_dict = decode_one_batch(
             params=params,
             model=model,

@@ -547,12 +590,18 @@ def decode_dataset(
             batch=batch,
         )
 
-        for name, hyps in hyps_dict.items():
+        for name, (hyps, timestamps_hyp) in hyps_dict.items():
             this_batch = []
-            assert len(hyps) == len(texts)
-            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+            assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
+                timestamps_ref
+            )
+            for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
+                cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
+            ):
                 ref_words = ref_text.split()
-                this_batch.append((cut_id, ref_words, hyp_words))
+                this_batch.append(
+                    (cut_id, ref_words, hyp_words, time_ref, time_hyp)
+                )
 
             results[name].extend(this_batch)
 

@@ -570,15 +619,19 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    results_dict: Dict[
+        str,
+        List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
+    ],
 ):
     test_set_wers = dict()
+    test_set_delays = dict()
     for key, results in results_dict.items():
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        store_transcripts_and_timestamps(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned

@@ -587,10 +640,11 @@ def save_results(
             params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
-            wer = write_error_stats(
+            wer, mean_delay, var_delay = write_error_stats_with_timestamps(
                 f, f"{test_set_name}-{key}", results, enable_log=True
             )
             test_set_wers[key] = wer
+            test_set_delays[key] = (mean_delay, var_delay)
 
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 

@@ -604,6 +658,19 @@ def save_results(
         for key, val in test_set_wers:
             print("{}\t{}".format(key, val), file=f)
 
+    test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
+    delays_info = (
+        params.res_dir
+        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(delays_info, "w") as f:
+        print("settings\tsymbol-delay", file=f)
+        for key, val in test_set_delays:
+            print(
+                "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
+                file=f,
+            )
+
     s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key, val in test_set_wers:

@@ -611,6 +678,15 @@ def save_results(
             note = ""
     logging.info(s)
 
+    s = "\nFor {}, symbol-delay of different settings are:\n".format(
+        test_set_name
+    )
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_delays:
+        s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
+        note = ""
+    logging.info(s)
+
 
 @torch.no_grad()
 def main():
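The hypothesis timestamps collected above are encoder frame indices after subsampling; parse_hyp_and_timestamp presumably converts them to seconds using the subsampling_factor and frame_shift_ms it is given. A sketch of that conversion under this assumption:

from typing import List


def frames_to_seconds(
    frame_indices: List[int],
    subsampling_factor: int = 4,
    frame_shift_ms: float = 10.0,
) -> List[float]:
    # Each encoder output frame covers subsampling_factor input frames,
    # and each input frame advances by frame_shift_ms milliseconds.
    return [
        i * subsampling_factor * frame_shift_ms / 1000.0 for i in frame_indices
    ]


# e.g. frame index 123 -> 123 * 4 * 10 ms = 4.92 s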
@@ -377,6 +377,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
+            "frame_shift_ms": 10.0,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -511,7 +511,7 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[str, Tuple[str, List[str], List[str]]]]:
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
     Args:
@@ -585,7 +585,7 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[str, Tuple[str, List[str], List[str]]]],
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
@@ -16,15 +16,22 @@
 
 import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 
 import k2
 import sentencepiece as spm
 import torch
 from model import Transducer
 
+from icefall import NgramLm, NgramLmStateCost
 from icefall.decode import Nbest, one_best_decoding
-from icefall.utils import add_eos, add_sos, get_texts
+from icefall.utils import (
+    DecodingResults,
+    add_eos,
+    add_sos,
+    get_texts,
+    get_texts_with_timestamp,
+)
 
 
 def fast_beam_search_one_best(
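DecodingResults itself is defined in icefall.utils and is not part of this diff; judging from how it is constructed and read in the hunks below, it is roughly the following dataclass (a sketch, not the actual definition):

from dataclasses import dataclass
from typing import List


@dataclass
class DecodingResults:
    # tokens[i] holds the decoded token ids of the i-th utterance.
    tokens: List[List[int]]

    # timestamps[i][k] is the frame index (after subsampling) at which
    # tokens[i][k] was emitted.
    timestamps: List[List[int]]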
@@ -36,7 +43,8 @@ def fast_beam_search_one_best(
     max_states: int,
     max_contexts: int,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.
 
     A lattice is first obtained using fast beam search, and then

@@ -60,8 +68,12 @@ def fast_beam_search_one_best(
       Max contexts pre stream per frame.
     temperature:
       Softmax temperature.
+    return_timestamps:
+      Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,

@@ -75,8 +87,11 @@ def fast_beam_search_one_best(
     )
 
     best_path = one_best_decoding(lattice)
-    hyps = get_texts(best_path)
 
-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)
 
 
 def fast_beam_search_nbest_LG(

@@ -91,7 +106,8 @@ def fast_beam_search_nbest_LG(
     nbest_scale: float = 0.5,
     use_double_scores: bool = True,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.
 
     The process to get the results is:

@@ -128,8 +144,12 @@ def fast_beam_search_nbest_LG(
       single precision.
     temperature:
       Softmax temperature.
+    return_timestamps:
+      Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,

@@ -194,9 +214,10 @@ def fast_beam_search_nbest_LG(
     best_hyp_indexes = ragged_tot_scores.argmax()
     best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
 
-    hyps = get_texts(best_path)
-
-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)
 
 
 def fast_beam_search_nbest(

@@ -211,7 +232,8 @@ def fast_beam_search_nbest(
     nbest_scale: float = 0.5,
     use_double_scores: bool = True,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.
 
     The process to get the results is:

@@ -248,8 +270,12 @@ def fast_beam_search_nbest(
       single precision.
     temperature:
       Softmax temperature.
+    return_timestamps:
+      Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,

@@ -278,9 +304,10 @@ def fast_beam_search_nbest(
 
     best_path = k2.index_fsa(nbest.fsa, max_indexes)
 
-    hyps = get_texts(best_path)
-
-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)
 
 
 def fast_beam_search_nbest_oracle(

@@ -296,7 +323,8 @@ def fast_beam_search_nbest_oracle(
     use_double_scores: bool = True,
     nbest_scale: float = 0.5,
     temperature: float = 1.0,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.
 
     A lattice is first obtained using fast beam search, and then

@@ -337,8 +365,12 @@ def fast_beam_search_nbest_oracle(
       yields more unique paths.
     temperature:
       Softmax temperature.
+    return_timestamps:
+      Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     lattice = fast_beam_search(
         model=model,

@@ -377,8 +409,10 @@ def fast_beam_search_nbest_oracle(
 
     best_path = k2.index_fsa(nbest.fsa, max_indexes)
 
-    hyps = get_texts(best_path)
-    return hyps
+    if not return_timestamps:
+        return get_texts(best_path)
+    else:
+        return get_texts_with_timestamp(best_path)
 
 
 def fast_beam_search(
@@ -468,8 +502,11 @@ def fast_beam_search(
 
 
 def greedy_search(
-    model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
-) -> List[int]:
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    max_sym_per_frame: int,
+    return_timestamps: bool = False,
+) -> Union[List[int], DecodingResults]:
     """Greedy search for a single utterance.
     Args:
       model:

@@ -479,8 +516,12 @@ def greedy_search(
     max_sym_per_frame:
       Maximum number of symbols per frame. If it is set to 0, the WER
       would be 100%.
+    return_timestamps:
+      Whether to return timestamps.
     Returns:
-      Return the decoded result.
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     assert encoder_out.ndim == 3
 

@@ -506,6 +547,10 @@ def greedy_search(
     t = 0
     hyp = [blank_id] * context_size
 
+    # timestamp[i] is the frame index after subsampling
+    # on which hyp[i] is decoded
+    timestamp = []
+
     # Maximum symbols per utterance.
     max_sym_per_utt = 1000
 

@@ -532,6 +577,7 @@ def greedy_search(
         y = logits.argmax().item()
         if y not in (blank_id, unk_id):
             hyp.append(y)
+            timestamp.append(t)
             decoder_input = torch.tensor(
                 [hyp[-context_size:]], device=device
             ).reshape(1, context_size)

@@ -546,14 +592,21 @@ def greedy_search(
             t += 1
     hyp = hyp[context_size:]  # remove blanks
 
-    return hyp
+    if not return_timestamps:
+        return hyp
+    else:
+        return DecodingResults(
+            tokens=[hyp],
+            timestamps=[timestamp],
+        )
 
 
 def greedy_search_batch(
     model: Transducer,
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
-) -> List[List[int]]:
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
     Args:
       model:

@@ -563,9 +616,12 @@ def greedy_search_batch(
     encoder_out_lens:
       A 1-D tensor of shape (N,), containing number of valid frames in
       encoder_out before padding.
+    return_timestamps:
+      Whether to return timestamps.
     Returns:
-      Return a list-of-list of token IDs containing the decoded results.
-      len(ans) equals to encoder_out.size(0).
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
     """
     assert encoder_out.ndim == 3
     assert encoder_out.size(0) >= 1, encoder_out.size(0)

@@ -590,6 +646,10 @@ def greedy_search_batch(
 
     hyps = [[blank_id] * context_size for _ in range(N)]
 
+    # timestamp[n][i] is the frame index after subsampling
+    # on which hyp[n][i] is decoded
+    timestamps = [[] for _ in range(N)]
+
     decoder_input = torch.tensor(
         hyps,
         device=device,

@@ -603,7 +663,7 @@ def greedy_search_batch(
     encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
 
     offset = 0
-    for batch_size in batch_size_list:
+    for (t, batch_size) in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]

@@ -625,6 +685,7 @@ def greedy_search_batch(
         for i, v in enumerate(y):
             if v not in (blank_id, unk_id):
                 hyps[i].append(v)
+                timestamps[i].append(t)
                 emitted = True
         if emitted:
             # update decoder output

@@ -639,11 +700,19 @@ def greedy_search_batch(
 
     sorted_ans = [h[context_size:] for h in hyps]
     ans = []
+    ans_timestamps = []
     unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])
+        ans_timestamps.append(timestamps[unsorted_indices[i]])
 
-    return ans
+    if not return_timestamps:
+        return ans
+    else:
+        return DecodingResults(
+            tokens=ans,
+            timestamps=ans_timestamps,
+        )
 
 
 @dataclass
@ -656,6 +725,12 @@ class Hypothesis:
     # It contains only one entry.
     log_prob: torch.Tensor
 
+    # timestamp[i] is the frame index after subsampling
+    # on which ys[i] is decoded
+    timestamp: List[int]
+
+    state_cost: Optional[NgramLmStateCost] = None
+
     @property
     def key(self) -> str:
         """Return a string representation of self.ys"""
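A small hedged sketch of how the new timestamp field travels with a hypothesis during search; the token id, frame index and score below are made up, and the import path assumes the beam_search.py module of the recipe directory.

import torch

from beam_search import Hypothesis  # the dataclass shown above; path assumed

blank_id = 0
context_size = 2

hyp = Hypothesis(
    ys=[blank_id] * context_size,
    log_prob=torch.zeros(1, dtype=torch.float32),
    timestamp=[],
)

# When a non-blank token (say 17) is emitted at frame t=3, ys and timestamp
# grow together, so ys[context_size + i] was emitted at frame timestamp[i].
new_hyp = Hypothesis(
    ys=hyp.ys + [17],
    log_prob=hyp.log_prob - 0.5,
    timestamp=hyp.timestamp + [3],
)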
@ -803,7 +878,8 @@ def modified_beam_search(
|
|||||||
encoder_out_lens: torch.Tensor,
|
encoder_out_lens: torch.Tensor,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
) -> List[List[int]]:
|
return_timestamps: bool = False,
|
||||||
|
) -> Union[List[List[int]], DecodingResults]:
|
||||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -818,9 +894,12 @@ def modified_beam_search(
|
|||||||
Number of active paths during the beam search.
|
Number of active paths during the beam search.
|
||||||
temperature:
|
temperature:
|
||||||
Softmax temperature.
|
Softmax temperature.
|
||||||
|
return_timestamps:
|
||||||
|
Whether to return timestamps.
|
||||||
Returns:
|
Returns:
|
||||||
Return a list-of-list of token IDs. ans[i] is the decoding results
|
If return_timestamps is False, return the decoded result.
|
||||||
for the i-th utterance.
|
Else, return a DecodingResults object containing
|
||||||
|
decoded result and corresponding timestamps.
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == 3, encoder_out.shape
|
assert encoder_out.ndim == 3, encoder_out.shape
|
||||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
@ -848,6 +927,7 @@ def modified_beam_search(
|
|||||||
Hypothesis(
|
Hypothesis(
|
||||||
ys=[blank_id] * context_size,
|
ys=[blank_id] * context_size,
|
||||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
timestamp=[],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -855,7 +935,7 @@ def modified_beam_search(
|
|||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
finalized_B = []
|
finalized_B = []
|
||||||
for batch_size in batch_size_list:
|
for (t, batch_size) in enumerate(batch_size_list):
|
||||||
start = offset
|
start = offset
|
||||||
end = offset + batch_size
|
end = offset + batch_size
|
||||||
current_encoder_out = encoder_out.data[start:end]
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
@ -933,30 +1013,44 @@ def modified_beam_search(
|
|||||||
|
|
||||||
new_ys = hyp.ys[:]
|
new_ys = hyp.ys[:]
|
||||||
new_token = topk_token_indexes[k]
|
new_token = topk_token_indexes[k]
|
||||||
|
new_timestamp = hyp.timestamp[:]
|
||||||
if new_token not in (blank_id, unk_id):
|
if new_token not in (blank_id, unk_id):
|
||||||
new_ys.append(new_token)
|
new_ys.append(new_token)
|
||||||
|
new_timestamp.append(t)
|
||||||
|
|
||||||
new_log_prob = topk_log_probs[k]
|
new_log_prob = topk_log_probs[k]
|
||||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
new_hyp = Hypothesis(
|
||||||
|
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
|
||||||
|
)
|
||||||
B[i].add(new_hyp)
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
B = B + finalized_B
|
B = B + finalized_B
|
||||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||||
|
|
||||||
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||||
|
sorted_timestamps = [h.timestamp for h in best_hyps]
|
||||||
ans = []
|
ans = []
|
||||||
|
ans_timestamps = []
|
||||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
for i in range(N):
|
for i in range(N):
|
||||||
ans.append(sorted_ans[unsorted_indices[i]])
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
|
||||||
|
|
||||||
return ans
|
if not return_timestamps:
|
||||||
|
return ans
|
||||||
|
else:
|
||||||
|
return DecodingResults(
|
||||||
|
tokens=ans,
|
||||||
|
timestamps=ans_timestamps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _deprecated_modified_beam_search(
|
def _deprecated_modified_beam_search(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
) -> List[int]:
|
return_timestamps: bool = False,
|
||||||
|
) -> Union[List[int], DecodingResults]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
It decodes only one utterance at a time. We keep it only for reference.
|
It decodes only one utterance at a time. We keep it only for reference.
|
||||||
@ -971,8 +1065,13 @@ def _deprecated_modified_beam_search(
|
|||||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||||
beam:
|
beam:
|
||||||
Beam size.
|
Beam size.
|
||||||
|
return_timestamps:
|
||||||
|
Whether to return timestamps.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoded result.
|
If return_timestamps is False, return the decoded result.
|
||||||
|
Else, return a DecodingResults object containing
|
||||||
|
decoded result and corresponding timestamps.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert encoder_out.ndim == 3
|
assert encoder_out.ndim == 3
|
||||||
@ -992,6 +1091,7 @@ def _deprecated_modified_beam_search(
|
|||||||
Hypothesis(
|
Hypothesis(
|
||||||
ys=[blank_id] * context_size,
|
ys=[blank_id] * context_size,
|
||||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
timestamp=[],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||||
@ -1050,17 +1150,24 @@ def _deprecated_modified_beam_search(
|
|||||||
for i in range(len(topk_hyp_indexes)):
|
for i in range(len(topk_hyp_indexes)):
|
||||||
hyp = A[topk_hyp_indexes[i]]
|
hyp = A[topk_hyp_indexes[i]]
|
||||||
new_ys = hyp.ys[:]
|
new_ys = hyp.ys[:]
|
||||||
|
new_timestamp = hyp.timestamp[:]
|
||||||
new_token = topk_token_indexes[i]
|
new_token = topk_token_indexes[i]
|
||||||
if new_token not in (blank_id, unk_id):
|
if new_token not in (blank_id, unk_id):
|
||||||
new_ys.append(new_token)
|
new_ys.append(new_token)
|
||||||
|
new_timestamp.append(t)
|
||||||
new_log_prob = topk_log_probs[i]
|
new_log_prob = topk_log_probs[i]
|
||||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
new_hyp = Hypothesis(
|
||||||
|
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
|
||||||
|
)
|
||||||
B.add(new_hyp)
|
B.add(new_hyp)
|
||||||
|
|
||||||
best_hyp = B.get_most_probable(length_norm=True)
|
best_hyp = B.get_most_probable(length_norm=True)
|
||||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||||
|
|
||||||
return ys
|
if not return_timestamps:
|
||||||
|
return ys
|
||||||
|
else:
|
||||||
|
return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
|
||||||
|
|
||||||
|
|
||||||
def beam_search(
|
def beam_search(
|
||||||
@ -1068,7 +1175,8 @@ def beam_search(
|
|||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
) -> List[int]:
|
return_timestamps: bool = False,
|
||||||
|
) -> Union[List[int], DecodingResults]:
|
||||||
"""
|
"""
|
||||||
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
|
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
|
||||||
|
|
||||||
@ -1083,8 +1191,13 @@ def beam_search(
|
|||||||
Beam size.
|
Beam size.
|
||||||
temperature:
|
temperature:
|
||||||
Softmax temperature.
|
Softmax temperature.
|
||||||
|
return_timestamps:
|
||||||
|
Whether to return timestamps.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoded result.
|
If return_timestamps is False, return the decoded result.
|
||||||
|
Else, return a DecodingResults object containing
|
||||||
|
decoded result and corresponding timestamps.
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == 3
|
assert encoder_out.ndim == 3
|
||||||
|
|
||||||
@ -1111,7 +1224,7 @@ def beam_search(
|
|||||||
t = 0
|
t = 0
|
||||||
|
|
||||||
B = HypothesisList()
|
B = HypothesisList()
|
||||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
|
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))
|
||||||
|
|
||||||
max_sym_per_utt = 20000
|
max_sym_per_utt = 20000
|
||||||
|
|
||||||
@ -1172,7 +1285,13 @@ def beam_search(
|
|||||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
||||||
|
|
||||||
# ys[:] returns a copy of ys
|
# ys[:] returns a copy of ys
|
||||||
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
B.add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=y_star.ys[:],
|
||||||
|
log_prob=new_y_star_log_prob,
|
||||||
|
timestamp=y_star.timestamp[:],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Second, process other non-blank labels
|
# Second, process other non-blank labels
|
||||||
values, indices = log_prob.topk(beam + 1)
|
values, indices = log_prob.topk(beam + 1)
|
||||||
@ -1181,7 +1300,14 @@ def beam_search(
|
|||||||
continue
|
continue
|
||||||
new_ys = y_star.ys + [i]
|
new_ys = y_star.ys + [i]
|
||||||
new_log_prob = y_star.log_prob + v
|
new_log_prob = y_star.log_prob + v
|
||||||
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
new_timestamp = y_star.timestamp + [t]
|
||||||
|
A.add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=new_ys,
|
||||||
|
log_prob=new_log_prob,
|
||||||
|
timestamp=new_timestamp,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Check whether B contains more than "beam" elements more probable
|
# Check whether B contains more than "beam" elements more probable
|
||||||
# than the most probable in A
|
# than the most probable in A
|
||||||
@ -1197,7 +1323,11 @@ def beam_search(
|
|||||||
|
|
||||||
best_hyp = B.get_most_probable(length_norm=True)
|
best_hyp = B.get_most_probable(length_norm=True)
|
||||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||||
return ys
|
|
||||||
|
if not return_timestamps:
|
||||||
|
return ys
|
||||||
|
else:
|
||||||
|
return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
|
||||||
|
|
||||||
|
|
||||||
def fast_beam_search_with_nbest_rescoring(
|
def fast_beam_search_with_nbest_rescoring(
|
||||||
@ -1217,7 +1347,8 @@ def fast_beam_search_with_nbest_rescoring(
|
|||||||
use_double_scores: bool = True,
|
use_double_scores: bool = True,
|
||||||
nbest_scale: float = 0.5,
|
nbest_scale: float = 0.5,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
) -> Dict[str, List[List[int]]]:
|
return_timestamps: bool = False,
|
||||||
|
) -> Dict[str, Union[List[List[int]], DecodingResults]]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
A lattice is first obtained using fast beam search, num_path are selected
|
A lattice is first obtained using fast beam search, num_path are selected
|
||||||
and rescored using a given language model. The shortest path within the
|
and rescored using a given language model. The shortest path within the
|
||||||
@ -1259,10 +1390,13 @@ def fast_beam_search_with_nbest_rescoring(
|
|||||||
yields more unique paths.
|
yields more unique paths.
|
||||||
temperature:
|
temperature:
|
||||||
Softmax temperature.
|
Softmax temperature.
|
||||||
|
return_timestamps:
|
||||||
|
Whether to return timestamps.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoded result in a dict, where the key has the form
|
Return the decoded result in a dict, where the key has the form
|
||||||
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
|
'ngram_lm_scale_xx' and the value is the decoded results
|
||||||
ngram LM scale value used during decoding, i.e., 0.1.
|
optionally with timestamps. `xx` is the ngram LM scale value
|
||||||
|
used during decoding, i.e., 0.1.
|
||||||
"""
|
"""
|
||||||
lattice = fast_beam_search(
|
lattice = fast_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -1340,16 +1474,18 @@ def fast_beam_search_with_nbest_rescoring(
|
|||||||
log_semiring=False,
|
log_semiring=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
ans: Dict[str, List[List[int]]] = {}
|
ans: Dict[str, Union[List[List[int]], DecodingResults]] = {}
|
||||||
for s in ngram_lm_scale_list:
|
for s in ngram_lm_scale_list:
|
||||||
key = f"ngram_lm_scale_{s}"
|
key = f"ngram_lm_scale_{s}"
|
||||||
tot_scores = am_scores.values + s * ngram_lm_scores
|
tot_scores = am_scores.values + s * ngram_lm_scores
|
||||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||||
max_indexes = ragged_tot_scores.argmax()
|
max_indexes = ragged_tot_scores.argmax()
|
||||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||||
hyps = get_texts(best_path)
|
|
||||||
|
|
||||||
ans[key] = hyps
|
if not return_timestamps:
|
||||||
|
ans[key] = get_texts(best_path)
|
||||||
|
else:
|
||||||
|
ans[key] = get_texts_with_timestamp(best_path)
|
||||||
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
@ -1373,7 +1509,8 @@ def fast_beam_search_with_nbest_rnn_rescoring(
|
|||||||
use_double_scores: bool = True,
|
use_double_scores: bool = True,
|
||||||
nbest_scale: float = 0.5,
|
nbest_scale: float = 0.5,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
) -> Dict[str, List[List[int]]]:
|
return_timestamps: bool = False,
|
||||||
|
) -> Dict[str, Union[List[List[int]], DecodingResults]]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
A lattice is first obtained using fast beam search, num_path are selected
|
A lattice is first obtained using fast beam search, num_path are selected
|
||||||
and rescored using a given language model and a rnn-lm.
|
and rescored using a given language model and a rnn-lm.
|
||||||
@ -1419,10 +1556,13 @@ def fast_beam_search_with_nbest_rnn_rescoring(
|
|||||||
yields more unique paths.
|
yields more unique paths.
|
||||||
temperature:
|
temperature:
|
||||||
Softmax temperature.
|
Softmax temperature.
|
||||||
|
return_timestamps:
|
||||||
|
Whether to return timestamps.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoded result in a dict, where the key has the form
|
Return the decoded result in a dict, where the key has the form
|
||||||
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
|
'ngram_lm_scale_xx' and the value is the decoded results
|
||||||
ngram LM scale value used during decoding, i.e., 0.1.
|
optionally with timestamps. `xx` is the ngram LM scale value
|
||||||
|
used during decoding, i.e., 0.1.
|
||||||
"""
|
"""
|
||||||
lattice = fast_beam_search(
|
lattice = fast_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -1534,8 +1674,180 @@ def fast_beam_search_with_nbest_rnn_rescoring(
|
|||||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||||
max_indexes = ragged_tot_scores.argmax()
|
max_indexes = ragged_tot_scores.argmax()
|
||||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||||
hyps = get_texts(best_path)
|
|
||||||
|
|
||||||
ans[key] = hyps
|
if not return_timestamps:
|
||||||
|
ans[key] = get_texts(best_path)
|
||||||
|
else:
|
||||||
|
ans[key] = get_texts_with_timestamp(best_path)
|
||||||
|
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def modified_beam_search_ngram_rescoring(
|
||||||
|
model: Transducer,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
ngram_lm: NgramLm,
|
||||||
|
ngram_lm_scale: float,
|
||||||
|
beam: int = 4,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The transducer model.
|
||||||
|
encoder_out:
|
||||||
|
Output from the encoder. Its shape is (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||||
|
encoder_out before padding.
|
||||||
|
beam:
|
||||||
|
Number of active paths during the beam search.
|
||||||
|
temperature:
|
||||||
|
Softmax temperature.
|
||||||
|
Returns:
|
||||||
|
Return a list-of-list of token IDs. ans[i] is the decoding results
|
||||||
|
for the i-th utterance.
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3, encoder_out.shape
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
lm_scale = ngram_lm_scale
|
||||||
|
|
||||||
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
B = [HypothesisList() for _ in range(N)]
|
||||||
|
for i in range(N):
|
||||||
|
B[i].add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=[blank_id] * context_size,
|
||||||
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
state_cost=NgramLmStateCost(ngram_lm),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
finalized_B = []
|
||||||
|
for batch_size in batch_size_list:
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||||
|
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
finalized_B = B[batch_size:] + finalized_B
|
||||||
|
B = B[:batch_size]
|
||||||
|
|
||||||
|
hyps_shape = get_hyps_shape(B).to(device)
|
||||||
|
|
||||||
|
A = [list(b) for b in B]
|
||||||
|
B = [HypothesisList() for _ in range(batch_size)]
|
||||||
|
|
||||||
|
ys_log_probs = torch.cat(
|
||||||
|
[
|
||||||
|
hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale
|
||||||
|
for hyps in A
|
||||||
|
for hyp in hyps
|
||||||
|
]
|
||||||
|
) # (num_hyps, 1)
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (num_hyps, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||||
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
|
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
|
||||||
|
|
||||||
|
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||||
|
# as index, so we use `to(torch.int64)` below.
|
||||||
|
current_encoder_out = torch.index_select(
|
||||||
|
current_encoder_out,
|
||||||
|
dim=0,
|
||||||
|
index=hyps_shape.row_ids(1).to(torch.int64),
|
||||||
|
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||||
|
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
project_input=False,
|
||||||
|
) # (num_hyps, 1, 1, vocab_size)
|
||||||
|
|
||||||
|
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
|
log_probs = (logits / temperature).log_softmax(
|
||||||
|
dim=-1
|
||||||
|
) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
|
log_probs.add_(ys_log_probs)
|
||||||
|
vocab_size = log_probs.size(-1)
|
||||||
|
log_probs = log_probs.reshape(-1)
|
||||||
|
|
||||||
|
row_splits = hyps_shape.row_splits(1) * vocab_size
|
||||||
|
log_probs_shape = k2.ragged.create_ragged_shape2(
|
||||||
|
row_splits=row_splits, cached_tot_size=log_probs.numel()
|
||||||
|
)
|
||||||
|
ragged_log_probs = k2.RaggedTensor(
|
||||||
|
shape=log_probs_shape, value=log_probs
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||||
|
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||||
|
|
||||||
|
for k in range(len(topk_hyp_indexes)):
|
||||||
|
hyp_idx = topk_hyp_indexes[k]
|
||||||
|
hyp = A[i][hyp_idx]
|
||||||
|
|
||||||
|
new_ys = hyp.ys[:]
|
||||||
|
new_token = topk_token_indexes[k]
|
||||||
|
if new_token not in (blank_id, unk_id):
|
||||||
|
new_ys.append(new_token)
|
||||||
|
state_cost = hyp.state_cost.forward_one_step(new_token)
|
||||||
|
else:
|
||||||
|
state_cost = hyp.state_cost
|
||||||
|
|
||||||
|
# We only keep AM scores in new_hyp.log_prob
|
||||||
|
new_log_prob = (
|
||||||
|
topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
new_hyp = Hypothesis(
|
||||||
|
ys=new_ys, log_prob=new_log_prob, state_cost=state_cost
|
||||||
|
)
|
||||||
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
|
B = B + finalized_B
|
||||||
|
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||||
|
|
||||||
|
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||||
|
ans = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
|
||||||
return ans
|
return ans
|
||||||
|
@ -380,14 +380,13 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    feature_lens += params.left_context
-    feature = torch.nn.functional.pad(
-        feature,
-        pad=(0, 0, 0, params.left_context),
-        value=LOG_EPS,
-    )
-
     if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,

@ -462,14 +462,13 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    feature_lens += params.left_context
-    feature = torch.nn.functional.pad(
-        feature,
-        pad=(0, 0, 0, params.left_context),
-        value=LOG_EPS,
-    )
-
     if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
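The padding that used to run unconditionally is now applied only under --simulate-streaming. As a reminder of what the pad tuple means: torch.nn.functional.pad fills the last dimensions first, so (0, 0, 0, left_context) leaves the feature dimension alone and appends left_context frames at the end of the time axis. A tiny hedged example with made-up shapes and a made-up left-context value:

import math

import torch
import torch.nn.functional as F

LOG_EPS = math.log(1e-10)   # same constant the decode scripts define
left_context = 64           # illustrative value of params.left_context

feature = torch.randn(2, 100, 80)       # (N, T, C) fbank features
padded = F.pad(
    feature,
    pad=(0, 0, 0, left_context),        # (C_left, C_right, T_front, T_back)
    value=LOG_EPS,
)
print(padded.shape)                     # torch.Size([2, 164, 80])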
@ -24,6 +24,11 @@ with the given torchscript model for the same input.
 import argparse
 import logging
 
+from icefall import is_module_available
+
+if not is_module_available("onnxruntime"):
+    raise ValueError("Please 'pip install onnxruntime' first.")
+
 import onnxruntime as ort
 import torch

@ -21,6 +21,11 @@ This file is to test that models can be exported to onnx.
 """
 import os
 
+from icefall import is_module_available
+
+if not is_module_available("onnxruntime"):
+    raise ValueError("Please 'pip install onnxruntime' first.")
+
 import onnxruntime as ort
 import torch
 from conformer import (
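Both scripts above now guard the onnxruntime import behind icefall's is_module_available helper, so a missing optional dependency produces a readable message instead of a bare ImportError. The snippet below is only a sketch of what such a helper can look like, built on importlib; the actual implementation in icefall.utils may differ.

import importlib.util


def is_module_available(*modules: str) -> bool:
    """Return True if every named module can be found by the import system."""
    return all(importlib.util.find_spec(m) is not None for m in modules)


if not is_module_available("onnxruntime"):
    raise ValueError("Please 'pip install onnxruntime' first.")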
@ -328,6 +328,7 @@ def get_parser():
         help="The probability to select a batch from the GigaSpeech dataset",
     )
 
+    add_model_arguments(parser)
     return parser
 
@ -106,6 +106,22 @@ Usage:
         --beam 20.0 \
         --max-contexts 8 \
         --max-states 64
+
+To evaluate symbol delay, you should:
+(1) Generate cuts with word-time alignments:
+    ./local/add_alignment_librispeech.py \
+        --alignments-dir data/alignment \
+        --cuts-in-dir data/fbank \
+        --cuts-out-dir data/fbank_ali
+(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
+For example:
+    ./pruned_transducer_stateless4/decode.py \
+        --epoch 40 \
+        --avg 20 \
+        --exp-dir ./pruned_transducer_stateless4/exp \
+        --max-duration 600 \
+        --decoding-method greedy_search \
+        --manifest-dir data/fbank_ali
 """
 
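The symbol delay that this PR starts reporting compares, for each aligned word, when the model emitted it against when it was spoken according to the word alignment. The real numbers come from write_error_stats_with_timestamps, which works on the edit-distance alignment; the following is only a simplified sketch of the mean/variance computation for the case where hypothesis and reference times line up one-to-one, with made-up times.

from typing import List, Tuple


def symbol_delay_stats(
    time_ref: List[float], time_hyp: List[float]
) -> Tuple[float, float]:
    """Mean and variance (in seconds) of hyp time minus ref time."""
    n = min(len(time_ref), len(time_hyp))
    if n == 0:
        return 0.0, 0.0
    delays = [time_hyp[i] - time_ref[i] for i in range(n)]
    mean = sum(delays) / n
    var = sum((d - mean) ** 2 for d in delays) / n
    return mean, var


# Made-up times: the hypothesis emits each word about 0.12 s late.
print(symbol_delay_stats([0.0, 0.5, 1.0], [0.12, 0.60, 1.15]))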
@ -142,10 +158,12 @@ from icefall.checkpoint import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    DecodingResults,
+    parse_hyp_and_timestamp,
     setup_logger,
-    store_transcripts,
+    store_transcripts_and_timestamps,
     str2bool,
-    write_error_stats,
+    write_error_stats_with_timestamps,
 )
 
 LOG_EPS = math.log(1e-10)
@ -318,7 +336,7 @@ def get_parser():
|
|||||||
"--left-context",
|
"--left-context",
|
||||||
type=int,
|
type=int,
|
||||||
default=64,
|
default=64,
|
||||||
help="left context can be seen during decoding (in frames after subsampling)",
|
help="left context can be seen during decoding (in frames after subsampling)", # noqa
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -350,7 +368,7 @@ def decode_one_batch(
|
|||||||
batch: dict,
|
batch: dict,
|
||||||
word_table: Optional[k2.SymbolTable] = None,
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
|
|
||||||
@ -358,9 +376,10 @@ def decode_one_batch(
|
|||||||
if greedy_search is used, it would be "greedy_search"
|
if greedy_search is used, it would be "greedy_search"
|
||||||
If beam search with a beam size of 7 is used, it would be
|
If beam search with a beam size of 7 is used, it would be
|
||||||
"beam_7"
|
"beam_7"
|
||||||
- value: It contains the decoding result. `len(value)` equals to
|
- value: It is a tuple. `len(value[0])` and `len(value[1])` are both
|
||||||
batch size. `value[i]` is the decoding result for the i-th
|
equal to the batch size. `value[0][i]` and `value[1][i]`
|
||||||
utterance in the given batch.
|
are the decoding result and timestamps for the i-th utterance
|
||||||
|
in the given batch respectively.
|
||||||
Args:
|
Args:
|
||||||
params:
|
params:
|
||||||
It's the return value of :func:`get_params`.
|
It's the return value of :func:`get_params`.
|
||||||
@ -379,8 +398,8 @@ def decode_one_batch(
|
|||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result and timestamps. See above description for the
|
||||||
the returned dict.
|
format of the returned dict.
|
||||||
"""
|
"""
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
@ -392,14 +411,13 @@ def decode_one_batch(
|
|||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
feature_lens += params.left_context
|
|
||||||
feature = torch.nn.functional.pad(
|
|
||||||
feature,
|
|
||||||
pad=(0, 0, 0, params.left_context),
|
|
||||||
value=LOG_EPS,
|
|
||||||
)
|
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
|
feature_lens += params.left_context
|
||||||
|
feature = torch.nn.functional.pad(
|
||||||
|
feature,
|
||||||
|
pad=(0, 0, 0, params.left_context),
|
||||||
|
value=LOG_EPS,
|
||||||
|
)
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
@ -412,10 +430,8 @@ def decode_one_batch(
|
|||||||
x=feature, x_lens=feature_lens
|
x=feature, x_lens=feature_lens
|
||||||
)
|
)
|
||||||
|
|
||||||
hyps = []
|
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
res = fast_beam_search_one_best(
|
||||||
model=model,
|
model=model,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
@ -423,11 +439,10 @@ def decode_one_batch(
|
|||||||
beam=params.beam,
|
beam=params.beam,
|
||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
|
return_timestamps=True,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
|
||||||
hyps.append(hyp.split())
|
|
||||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||||
hyp_tokens = fast_beam_search_nbest_LG(
|
res = fast_beam_search_nbest_LG(
|
||||||
model=model,
|
model=model,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
@ -437,11 +452,10 @@ def decode_one_batch(
|
|||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
nbest_scale=params.nbest_scale,
|
nbest_scale=params.nbest_scale,
|
||||||
|
return_timestamps=True,
|
||||||
)
|
)
|
||||||
for hyp in hyp_tokens:
|
|
||||||
hyps.append([word_table[i] for i in hyp])
|
|
||||||
elif params.decoding_method == "fast_beam_search_nbest":
|
elif params.decoding_method == "fast_beam_search_nbest":
|
||||||
hyp_tokens = fast_beam_search_nbest(
|
res = fast_beam_search_nbest(
|
||||||
model=model,
|
model=model,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
@ -451,11 +465,10 @@ def decode_one_batch(
|
|||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
nbest_scale=params.nbest_scale,
|
nbest_scale=params.nbest_scale,
|
||||||
|
return_timestamps=True,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
|
||||||
hyps.append(hyp.split())
|
|
||||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
res = fast_beam_search_nbest_oracle(
|
||||||
model=model,
|
model=model,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
@ -466,56 +479,67 @@ def decode_one_batch(
|
|||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
ref_texts=sp.encode(supervisions["text"]),
|
ref_texts=sp.encode(supervisions["text"]),
|
||||||
nbest_scale=params.nbest_scale,
|
nbest_scale=params.nbest_scale,
|
||||||
|
return_timestamps=True,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
|
||||||
hyps.append(hyp.split())
|
|
||||||
elif (
|
elif (
|
||||||
params.decoding_method == "greedy_search"
|
params.decoding_method == "greedy_search"
|
||||||
and params.max_sym_per_frame == 1
|
and params.max_sym_per_frame == 1
|
||||||
):
|
):
|
||||||
hyp_tokens = greedy_search_batch(
|
res = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
return_timestamps=True,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
|
||||||
hyps.append(hyp.split())
|
|
||||||
elif params.decoding_method == "modified_beam_search":
|
elif params.decoding_method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
res = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
|
return_timestamps=True,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
|
||||||
hyps.append(hyp.split())
|
|
||||||
else:
|
else:
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
|
tokens = []
|
||||||
|
timestamps = []
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
|
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
hyp = greedy_search(
|
res = greedy_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out_i,
|
encoder_out=encoder_out_i,
|
||||||
max_sym_per_frame=params.max_sym_per_frame,
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
|
return_timestamps=True,
|
||||||
)
|
)
|
||||||
elif params.decoding_method == "beam_search":
|
elif params.decoding_method == "beam_search":
|
||||||
hyp = beam_search(
|
res = beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out_i,
|
encoder_out=encoder_out_i,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
|
return_timestamps=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
)
|
)
|
||||||
hyps.append(sp.decode(hyp).split())
|
tokens.extend(res.tokens)
|
||||||
|
timestamps.extend(res.timestamps)
|
||||||
|
res = DecodingResults(tokens=tokens, timestamps=timestamps)
|
||||||
|
|
||||||
|
hyps, timestamps = parse_hyp_and_timestamp(
|
||||||
|
decoding_method=params.decoding_method,
|
||||||
|
res=res,
|
||||||
|
sp=sp,
|
||||||
|
subsampling_factor=params.subsampling_factor,
|
||||||
|
frame_shift_ms=params.frame_shift_ms,
|
||||||
|
word_table=word_table,
|
||||||
|
)
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
return {"greedy_search": hyps}
|
return {"greedy_search": (hyps, timestamps)}
|
||||||
elif "fast_beam_search" in params.decoding_method:
|
elif "fast_beam_search" in params.decoding_method:
|
||||||
key = f"beam_{params.beam}_"
|
key = f"beam_{params.beam}_"
|
||||||
key += f"max_contexts_{params.max_contexts}_"
|
key += f"max_contexts_{params.max_contexts}_"
|
||||||
@ -526,9 +550,9 @@ def decode_one_batch(
|
|||||||
if "LG" in params.decoding_method:
|
if "LG" in params.decoding_method:
|
||||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||||
|
|
||||||
return {key: hyps}
|
return {key: (hyps, timestamps)}
|
||||||
else:
|
else:
|
||||||
return {f"beam_size_{params.beam_size}": hyps}
|
return {f"beam_size_{params.beam_size}": (hyps, timestamps)}
|
||||||
|
|
||||||
|
|
||||||
def decode_dataset(
|
def decode_dataset(
|
||||||
@ -538,7 +562,9 @@ def decode_dataset(
|
|||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
word_table: Optional[k2.SymbolTable] = None,
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
) -> Dict[
|
||||||
|
str, List[Tuple[str, List[str], List[str], List[float], List[float]]]
|
||||||
|
]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -559,9 +585,12 @@ def decode_dataset(
|
|||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
Its value is a list of tuples. Each tuple contains two elements:
|
Its value is a list of tuples. Each tuple contains five elements:
|
||||||
The first is the reference transcript, and the second is the
|
- cut_id
|
||||||
predicted result.
|
- reference transcript
|
||||||
|
- predicted result
|
||||||
|
- timestamp of reference transcript
|
||||||
|
- timestamp of predicted result
|
||||||
"""
|
"""
|
||||||
num_cuts = 0
|
num_cuts = 0
|
||||||
|
|
||||||
@ -580,6 +609,18 @@ def decode_dataset(
|
|||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
|
timestamps_ref = []
|
||||||
|
for cut in batch["supervisions"]["cut"]:
|
||||||
|
for s in cut.supervisions:
|
||||||
|
time = []
|
||||||
|
if s.alignment is not None and "word" in s.alignment:
|
||||||
|
time = [
|
||||||
|
aliword.start
|
||||||
|
for aliword in s.alignment["word"]
|
||||||
|
if aliword.symbol != ""
|
||||||
|
]
|
||||||
|
timestamps_ref.append(time)
|
||||||
|
|
||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
@ -589,12 +630,18 @@ def decode_dataset(
|
|||||||
batch=batch,
|
batch=batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, hyps in hyps_dict.items():
|
for name, (hyps, timestamps_hyp) in hyps_dict.items():
|
||||||
this_batch = []
|
this_batch = []
|
||||||
assert len(hyps) == len(texts)
|
assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
|
||||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
timestamps_ref
|
||||||
|
)
|
||||||
|
for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
|
||||||
|
cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
|
||||||
|
):
|
||||||
ref_words = ref_text.split()
|
ref_words = ref_text.split()
|
||||||
this_batch.append((cut_id, ref_words, hyp_words))
|
this_batch.append(
|
||||||
|
(cut_id, ref_words, hyp_words, time_ref, time_hyp)
|
||||||
|
)
|
||||||
|
|
||||||
results[name].extend(this_batch)
|
results[name].extend(this_batch)
|
||||||
|
|
||||||
@ -612,15 +659,19 @@ def decode_dataset(
|
|||||||
def save_results(
|
def save_results(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
test_set_name: str,
|
test_set_name: str,
|
||||||
results_dict: Dict[str, List[str, Tuple[List[str], List[str]]]],
|
results_dict: Dict[
|
||||||
|
str,
|
||||||
|
List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
|
||||||
|
],
|
||||||
):
|
):
|
||||||
test_set_wers = dict()
|
test_set_wers = dict()
|
||||||
|
test_set_delays = dict()
|
||||||
for key, results in results_dict.items():
|
for key, results in results_dict.items():
|
||||||
recog_path = (
|
recog_path = (
|
||||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
)
|
)
|
||||||
results = sorted(results)
|
results = sorted(results)
|
||||||
store_transcripts(filename=recog_path, texts=results)
|
store_transcripts_and_timestamps(filename=recog_path, texts=results)
|
||||||
logging.info(f"The transcripts are stored in {recog_path}")
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
# The following prints out WERs, per-word error statistics and aligned
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
@ -629,10 +680,11 @@ def save_results(
|
|||||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
)
|
)
|
||||||
with open(errs_filename, "w") as f:
|
with open(errs_filename, "w") as f:
|
||||||
wer = write_error_stats(
|
wer, mean_delay, var_delay = write_error_stats_with_timestamps(
|
||||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||||
)
|
)
|
||||||
test_set_wers[key] = wer
|
test_set_wers[key] = wer
|
||||||
|
test_set_delays[key] = (mean_delay, var_delay)
|
||||||
|
|
||||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||||
|
|
||||||
@ -646,6 +698,19 @@ def save_results(
|
|||||||
for key, val in test_set_wers:
|
for key, val in test_set_wers:
|
||||||
print("{}\t{}".format(key, val), file=f)
|
print("{}\t{}".format(key, val), file=f)
|
||||||
|
|
||||||
|
test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
|
||||||
|
delays_info = (
|
||||||
|
params.res_dir
|
||||||
|
/ f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
with open(delays_info, "w") as f:
|
||||||
|
print("settings\tsymbol-delay", file=f)
|
||||||
|
for key, val in test_set_delays:
|
||||||
|
print(
|
||||||
|
"{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
|
||||||
|
file=f,
|
||||||
|
)
|
||||||
|
|
||||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||||
note = "\tbest for {}".format(test_set_name)
|
note = "\tbest for {}".format(test_set_name)
|
||||||
for key, val in test_set_wers:
|
for key, val in test_set_wers:
|
||||||
@ -653,6 +718,15 @@ def save_results(
|
|||||||
note = ""
|
note = ""
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
|
s = "\nFor {}, symbol-delay of different settings are:\n".format(
|
||||||
|
test_set_name
|
||||||
|
)
|
||||||
|
note = "\tbest for {}".format(test_set_name)
|
||||||
|
for key, val in test_set_delays:
|
||||||
|
s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
|
||||||
|
note = ""
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
|
@ -386,6 +386,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
+            "frame_shift_ms": 10.0,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@ -378,14 +378,13 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    feature_lens += params.left_context
-    feature = torch.nn.functional.pad(
-        feature,
-        pad=(0, 0, 0, params.left_context),
-        value=LOG_EPS,
-    )
-
     if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
@ -21,7 +21,6 @@ import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
-from multi_quantization.prediction import JointCodebookLoss
 from scaling import ScaledLinear
 
 from icefall.utils import add_sos

@ -74,6 +73,14 @@ class Transducer(nn.Module):
                 encoder_dim, vocab_size, initial_speed=0.5
             )
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
+
+        from icefall import is_module_available
+
+        if not is_module_available("multi_quantization"):
+            raise ValueError("Please 'pip install multi_quantization' first.")
+
+        from multi_quantization.prediction import JointCodebookLoss
+
         if num_codebooks > 0:
             self.codebook_loss_net = JointCodebookLoss(
                 predictor_channels=encoder_dim,
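Moving the multi_quantization import from module level into Transducer.__init__ means the recipe can still be imported when the optional package is absent; the check only fires when a model that needs the codebook loss is actually built. A generic sketch of the same lazy-import pattern, with a made-up package and class name:

def build_codebook_loss():
    # Import lazily so that code paths which never call this function
    # do not require the optional dependency to be installed.
    from icefall import is_module_available

    if not is_module_available("some_optional_pkg"):  # hypothetical package
        raise ValueError("Please 'pip install some_optional_pkg' first.")

    import some_optional_pkg

    return some_optional_pkg.Loss()  # hypothetical class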
@ -28,18 +28,21 @@ from typing import List, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import multi_quantization as quantization
|
|
||||||
|
|
||||||
|
from icefall import is_module_available
|
||||||
|
|
||||||
|
if not is_module_available("multi_quantization"):
|
||||||
|
raise ValueError("Please 'pip install multi_quantization' first.")
|
||||||
|
|
||||||
|
import multi_quantization as quantization
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from hubert_xlarge import HubertXlargeFineTuned
|
from hubert_xlarge import HubertXlargeFineTuned
|
||||||
from icefall.utils import (
|
|
||||||
AttributeDict,
|
|
||||||
setup_logger,
|
|
||||||
)
|
|
||||||
from lhotse import CutSet, load_manifest
|
from lhotse import CutSet, load_manifest
|
||||||
from lhotse.cut import MonoCut
|
from lhotse.cut import MonoCut
|
||||||
from lhotse.features.io import NumpyHdf5Writer
|
from lhotse.features.io import NumpyHdf5Writer
|
||||||
|
|
||||||
|
from icefall.utils import AttributeDict, setup_logger
|
||||||
|
|
||||||
|
|
||||||
class CodebookIndexExtractor:
|
class CodebookIndexExtractor:
|
||||||
"""
|
"""
|
||||||
|
@ -327,7 +327,7 @@ def decode_dataset(
|
|||||||
def save_results(
|
def save_results(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
test_set_name: str,
|
test_set_name: str,
|
||||||
results_dict: Dict[str, List[str, Tuple[List[str], List[str]]]],
|
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||||
):
|
):
|
||||||
test_set_wers = dict()
|
test_set_wers = dict()
|
||||||
for key, results in results_dict.items():
|
for key, results in results_dict.items():
|
||||||
|
@ -40,6 +40,11 @@ https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_s
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from icefall import is_module_available
|
||||||
|
|
||||||
|
if not is_module_available("onnxruntime"):
|
||||||
|
raise ValueError("Please 'pip install onnxruntime' first.")
|
||||||
|
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -49,6 +49,12 @@ from typing import List
|
|||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from icefall import is_module_available
|
||||||
|
|
||||||
|
if not is_module_available("onnxruntime"):
|
||||||
|
raise ValueError("Please 'pip install onnxruntime' first.")
|
||||||
|
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
@ -50,6 +50,7 @@ from .utils import (
|
|||||||
get_executor,
|
get_executor,
|
||||||
get_texts,
|
get_texts,
|
||||||
is_jit_tracing,
|
is_jit_tracing,
|
||||||
|
is_module_available,
|
||||||
l1_norm,
|
l1_norm,
|
||||||
l2_norm,
|
l2_norm,
|
||||||
linf_norm,
|
linf_norm,
|
||||||
@ -65,3 +66,5 @@ from .utils import (
|
|||||||
subsequent_chunk_mask,
|
subsequent_chunk_mask,
|
||||||
write_error_stats,
|
write_error_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .ngram_lm import NgramLm, NgramLmStateCost
|
||||||
|
171 icefall/ngram_lm.py (new file)
@ -0,0 +1,171 @@
|
|||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from icefall.utils import is_module_available
|
||||||
|
|
||||||
|
|
||||||
|
class NgramLm:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fst_filename: str,
|
||||||
|
backoff_id: int,
|
||||||
|
is_binary: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
fst_filename:
|
||||||
|
Path to the FST.
|
||||||
|
backoff_id:
|
||||||
|
ID of the backoff symbol.
|
||||||
|
is_binary:
|
||||||
|
True if the given file is a binary FST.
|
||||||
|
"""
|
||||||
|
if not is_module_available("kaldifst"):
|
||||||
|
raise ValueError("Please 'pip install kaldifst' first.")
|
||||||
|
|
||||||
|
import kaldifst
|
||||||
|
|
||||||
|
if is_binary:
|
||||||
|
lm = kaldifst.StdVectorFst.read(fst_filename)
|
||||||
|
else:
|
||||||
|
with open(fst_filename, "r") as f:
|
||||||
|
lm = kaldifst.compile(f.read(), acceptor=False)
|
||||||
|
|
||||||
|
if not lm.is_ilabel_sorted:
|
||||||
|
kaldifst.arcsort(lm, sort_type="ilabel")
|
||||||
|
|
||||||
|
self.lm = lm
|
||||||
|
self.backoff_id = backoff_id

    def _process_backoff_arcs(
        self,
        state: int,
        cost: float,
    ) -> List[Tuple[int, float]]:
        """Similar to ProcessNonemitting() from Kaldi, this function
        returns the list of states reachable from the given state via
        backoff arcs.

        Args:
          state:
            The input state.
          cost:
            The cost of reaching the given state from the start state.
        Returns:
          Return a list, where each element contains a tuple with two entries:
            - next_state
            - cost of next_state
          If there is no backoff arc leaving the input state, then return
          an empty list.
        """
        ans = []

        next_state, next_cost = self._get_next_state_and_cost_without_backoff(
            state=state,
            label=self.backoff_id,
        )
        if next_state is None:
            return ans
        ans.append((next_state, next_cost + cost))
        ans += self._process_backoff_arcs(next_state, next_cost + cost)
        return ans

    def _get_next_state_and_cost_without_backoff(
        self, state: int, label: int
    ) -> Tuple[int, float]:
        """Return (next_state, cost) of the arc leaving ``state`` whose
        ilabel equals ``label``, without following backoff arcs.

        Return (None, None) if no such arc exists.
        """
        import kaldifst

        arc_iter = kaldifst.ArcIterator(self.lm, state)
        num_arcs = self.lm.num_arcs(state)

        # The LM is arc sorted by ilabel, so we use binary search below.
        left = 0
        right = num_arcs - 1
        while left <= right:
            mid = (left + right) // 2
            arc_iter.seek(mid)
            arc = arc_iter.value
            if arc.ilabel < label:
                left = mid + 1
            elif arc.ilabel > label:
                right = mid - 1
            else:
                return arc.nextstate, arc.weight.value

        return None, None

    def get_next_state_and_cost(
        self,
        state: int,
        label: int,
    ) -> Tuple[List[int], List[float]]:
        """Return the next states and accumulated costs obtained by consuming
        ``label`` from ``state``, also trying every state reachable from
        ``state`` via backoff arcs.
        """
        states = [state]
        costs = [0]

        extra_states_costs = self._process_backoff_arcs(
            state=state,
            cost=0,
        )

        for s, c in extra_states_costs:
            states.append(s)
            costs.append(c)

        next_states = []
        next_costs = []
        for s, c in zip(states, costs):
            ns, nc = self._get_next_state_and_cost_without_backoff(s, label)
            if ns:
                next_states.append(ns)
                next_costs.append(c + nc)

        return next_states, next_costs


class NgramLmStateCost:
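    """Tracks the set of LM states reachable after consuming a sequence of
    tokens, together with the best (lowest) cost of reaching each state;
    ``lm_score`` is the negated best cost over those states.
    """
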
    def __init__(self, ngram_lm: NgramLm, state_cost: Optional[dict] = None):
        assert ngram_lm.lm.start == 0, ngram_lm.lm.start
        self.ngram_lm = ngram_lm
        if state_cost is not None:
            self.state_cost = state_cost
        else:
            self.state_cost = defaultdict(lambda: float("inf"))

            # At the very beginning, we are at the start state with cost 0
            self.state_cost[0] = 0.0

    def forward_one_step(self, label: int) -> "NgramLmStateCost":
        state_cost = defaultdict(lambda: float("inf"))
        for s, c in self.state_cost.items():
            next_states, next_costs = self.ngram_lm.get_next_state_and_cost(
                s,
                label,
            )
            for ns, nc in zip(next_states, next_costs):
                state_cost[ns] = min(state_cost[ns], c + nc)

        return NgramLmStateCost(ngram_lm=self.ngram_lm, state_cost=state_cost)

    @property
    def lm_score(self) -> float:
        if len(self.state_cost) == 0:
            return float("-inf")

        return -1 * min(self.state_cost.values())
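

# A minimal usage sketch (the file name and variable names here are
# hypothetical; see test/test_ngram_lm.py at the end of this commit for a
# runnable example):
#
#   lm = NgramLm("G.fst", backoff_id=3, is_binary=True)
#   state = NgramLmStateCost(lm)
#   for token in hyp_tokens:
#       state = state.forward_one_step(token)
#   score = state.lm_score  # negated best accumulated cost of hyp_tokens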
460 icefall/utils.py
@ -24,9 +24,10 @@ import re
import subprocess
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

import k2
import k2.version
@ -248,6 +249,86 @@ def get_texts(
    return aux_labels.tolist()


@dataclass
class DecodingResults:
    # Decoded token IDs for each utterance in the batch
    tokens: List[List[int]]

    # timestamps[i][k] contains the frame number on which tokens[i][k]
    # is decoded
    timestamps: List[List[int]]

    # hyps[i] is the recognition results, i.e., word IDs
    # for the i-th utterance with fast_beam_search_nbest_LG.
    hyps: Union[List[List[int]], k2.RaggedTensor] = None
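
    # Illustrative contents for a batch of two utterances (values made up):
    #   tokens     = [[25, 7, 31], [18, 4]]
    #   timestamps = [[0, 3, 9], [2, 6]]
    # i.e. tokens[0][2] == 31 of the first utterance was decoded on
    # (subsampled) frame 9.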


def get_tokens_and_timestamps(labels: List[int]) -> Tuple[List[int], List[int]]:
    tokens = []
    timestamps = []
    for i, v in enumerate(labels):
        if v != 0:
            tokens.append(v)
            timestamps.append(i)

    return tokens, timestamps
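

# E.g. labels [0, 0, 5, 0, 9, 9, 0] -> tokens [5, 9, 9], timestamps [2, 4, 5]:
# zeros (blanks) are dropped and each kept label remembers the frame index on
# which it was decoded.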


def get_texts_with_timestamp(
    best_paths: k2.Fsa, return_ragged: bool = False
) -> DecodingResults:
    """Extract the texts (as word IDs) and timestamps from the best-path FSAs.
    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
      return_ragged:
        True to return a ragged tensor with two axes [utt][word_id].
        False to return a list-of-list word IDs.
    Returns:
      Return a DecodingResults object containing the decoded token IDs, the
      frame indices on which they were decoded, and the word-ID hypotheses.
    """
    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
        # remove 0's and -1's.
        aux_labels = best_paths.aux_labels.remove_values_leq(0)
        # TODO: change arcs.shape() to arcs.shape
        aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)

        # remove the states and arcs axes.
        aux_shape = aux_shape.remove_axis(1)
        aux_shape = aux_shape.remove_axis(1)
        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
    else:
        # remove axis corresponding to states.
        aux_shape = best_paths.arcs.shape().remove_axis(1)
        aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
        # remove 0's and -1's.
        aux_labels = aux_labels.remove_values_leq(0)

    assert aux_labels.num_axes == 2

    labels_shape = best_paths.arcs.shape().remove_axis(1)
    labels_list = k2.RaggedTensor(
        labels_shape, best_paths.labels.contiguous()
    ).tolist()

    tokens = []
    timestamps = []
    for labels in labels_list:
        token, time = get_tokens_and_timestamps(labels[:-1])
        tokens.append(token)
        timestamps.append(time)

    return DecodingResults(
        tokens=tokens,
        timestamps=timestamps,
        hyps=aux_labels if return_ragged else aux_labels.tolist(),
    )


def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
    """Extract labels or aux_labels from the best-path FSAs.

@ -352,6 +433,33 @@ def store_transcripts(
            print(f"{cut_id}:\thyp={hyp}", file=f)


def store_transcripts_and_timestamps(
    filename: Pathlike,
    texts: Iterable[Tuple[str, List[str], List[str], List[float], List[float]]],
) -> None:
    """Save predicted results and reference transcripts as well as their
    timestamps to a file.

    Args:
      filename:
        File to save the results to.
      texts:
        An iterable of tuples. The first element is the cut_id; the second
        and third elements are the reference transcript and the predicted
        result; the fourth and fifth elements are the timestamps of the
        reference and of the prediction, respectively.
    Returns:
      Return None.
    """
    with open(filename, "w") as f:
        for cut_id, ref, hyp, time_ref, time_hyp in texts:
            print(f"{cut_id}:\tref={ref}", file=f)
            print(f"{cut_id}:\thyp={hyp}", file=f)
            if len(time_ref) > 0:
                s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]"
                print(f"{cut_id}:\ttimestamp_ref={s}", file=f)
            s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
            print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)


def write_error_stats(
    f: TextIO,
    test_set_name: str,
@ -519,6 +627,211 @@ def write_error_stats(
    return float(tot_err_rate)


def write_error_stats_with_timestamps(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, List[str], List[str], List[float], List[float]]],
    enable_log: bool = True,
) -> Tuple[float, float, float]:
    """Write statistics based on predicted results and reference transcripts
    as well as their timestamps.

    It will write the following to the given file:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted result.
          An instance is given below::

            THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted to `ADDISON` (a substitution error).

          Another example is::

            FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).

    Args:
      results:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
      enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.

    Returns:
      Return the total word error rate, the mean symbol delay, and the
      variance of the symbol delay.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"
    # Compute mean alignment delay on the correct words
    all_delay = []
    for cut_id, ref, hyp, time_ref, time_hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        has_time_ref = len(time_ref) > 0
        if has_time_ref:
            # pointer to timestamp_hyp
            p_hyp = 0
            # pointer to timestamp_ref
            p_ref = 0
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
                if has_time_ref:
                    p_hyp += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
                if has_time_ref:
                    p_ref += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
                if has_time_ref:
                    p_hyp += 1
                    p_ref += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
                if has_time_ref:
                    all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
                    p_hyp += 1
                    p_ref += 1
        if has_time_ref:
            assert p_hyp == len(hyp), (p_hyp, len(hyp))
            assert p_ref == len(ref), (p_ref, len(ref))

    ref_len = sum([len(r) for _, r, _, _, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    mean_delay = "inf"
    var_delay = "inf"
    num_delay = len(all_delay)
    if num_delay > 0:
        mean_delay = sum(all_delay) / num_delay
        var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
        mean_delay = "%.3f" % mean_delay
        var_delay = "%.3f" % var_delay

    if enable_log:
        logging.info(
            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )
        logging.info(
            f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} "  # noqa
            f"computed on {num_delay} correct words"
        )

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
    for cut_id, ref, hyp, _, _ in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [
                [
                    list(filter(lambda a: a != ERR, x)),
                    list(filter(lambda a: a != ERR, y)),
                ]
                for x, y in ali
            ]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [
                [
                    ERR if x == [] else " ".join(x),
                    ERR if y == [] else " ".join(y),
                ]
                for x, y in ali
            ]

        print(
            f"{cut_id}:\t"
            + " ".join(
                (
                    ref_word
                    if ref_word == hyp_word
                    else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in ali
                )
            ),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    for count, (ref, hyp) in sorted(
        [(v, k) for k, v in subs.items()], reverse=True
    ):
        print(f"{count} {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count} {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count} {hyp}", file=f)

    print("", file=f)
    print(
        "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
    )
    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        (corr, ref_sub, hyp_sub, ins, dels) = counts
        tot_errs = ref_sub + hyp_sub + ins + dels
        ref_count = corr + ref_sub + dels
        hyp_count = corr + hyp_sub + ins

        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate), float(mean_delay), float(var_delay)
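

# Delay bookkeeping, illustrated: if the correctly recognised words have
# time_hyp - time_ref equal to [0.04, 0.08, 0.00] seconds, then
# mean_delay = 0.040 and
# var_delay = (0.0**2 + 0.04**2 + (-0.04)**2) / 3 ≈ 0.001.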


class MetricsTracker(collections.defaultdict):
    def __init__(self):
        # Passing the type 'int' to the base-class constructor
@ -976,3 +1289,148 @@ def display_and_save_batch(
    y = sp.encode(supervisions["text"], out_type=int)
    num_tokens = sum(len(i) for i in y)
    logging.info(f"num tokens: {num_tokens}")


def convert_timestamp(
    frames: List[int],
    subsampling_factor: int,
    frame_shift_ms: float = 10,
) -> List[float]:
    """Convert frame numbers to time (in seconds) given subsampling factor
    and frame shift (in milliseconds).

    Args:
      frames:
        A list of frame numbers after subsampling.
      subsampling_factor:
        The subsampling factor of the model.
      frame_shift_ms:
        Frame shift in milliseconds between two contiguous frames.
    Return:
      Return the time in seconds corresponding to each given frame.
    """
    frame_shift = frame_shift_ms / 1000.0
    time = []
    for f in frames:
        time.append(f * subsampling_factor * frame_shift)

    return time
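

# E.g. frames [0, 2, 5] with subsampling_factor=4 and frame_shift_ms=10 map to
# [0.0, 0.08, 0.2] seconds, since each subsampled frame covers 4 * 10 ms.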


def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
    """
    Parse timestamp of each word.

    Args:
      tokens:
        List of tokens.
      timestamp:
        List of timestamp of each token.

    Returns:
      List of timestamp of each word.
    """
    start_token = b"\xe2\x96\x81".decode()  # '▁'
    assert len(tokens) == len(timestamp)
    ans = []
    for i in range(len(tokens)):
        flag = False
        if i == 0 or tokens[i].startswith(start_token):
            flag = True
            if len(tokens[i]) == 1 and tokens[i].startswith(start_token):
                # tokens[i] == start_token
                if i == len(tokens) - 1:
                    # it is the last token
                    flag = False
                elif tokens[i + 1].startswith(start_token):
                    # the next token also starts with start_token
                    flag = False
        if flag:
            ans.append(timestamp[i])
    return ans
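

# E.g. for BPE pieces ['▁HE', 'LL', 'O', '▁WORLD'] with token timestamps
# [0.0, 0.04, 0.08, 0.2], only the word-initial pieces are kept, giving word
# timestamps [0.0, 0.2] for "HELLO WORLD".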


def parse_hyp_and_timestamp(
    res: DecodingResults,
    decoding_method: str,
    sp: spm.SentencePieceProcessor,
    subsampling_factor: int,
    frame_shift_ms: float = 10,
    word_table: Optional[k2.SymbolTable] = None,
) -> Tuple[List[List[str]], List[List[float]]]:
    """Parse hypothesis and timestamp.

    Args:
      res:
        A DecodingResults object.
      decoding_method:
        Possible values are:
          - greedy_search
          - beam_search
          - modified_beam_search
          - fast_beam_search
          - fast_beam_search_nbest
          - fast_beam_search_nbest_oracle
          - fast_beam_search_nbest_LG
      sp:
        The BPE model.
      subsampling_factor:
        The integer subsampling factor.
      frame_shift_ms:
        The float frame shift used for feature extraction.
      word_table:
        The word symbol table.

    Returns:
      Return a pair of lists: the word-level hypotheses and the corresponding
      word-level timestamps (in seconds), one entry per utterance.
    """
    assert decoding_method in (
        "greedy_search",
        "beam_search",
        "fast_beam_search",
        "fast_beam_search_nbest",
        "fast_beam_search_nbest_LG",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
    )

    hyps = []
    timestamps = []

    N = len(res.tokens)
    assert len(res.timestamps) == N
    use_word_table = False
    if decoding_method == "fast_beam_search_nbest_LG":
        assert word_table is not None
        use_word_table = True

    for i in range(N):
        tokens = sp.id_to_piece(res.tokens[i])
        if use_word_table:
            words = [word_table[i] for i in res.hyps[i]]
        else:
            words = sp.decode_pieces(tokens).split()
        time = convert_timestamp(
            res.timestamps[i], subsampling_factor, frame_shift_ms
        )
        time = parse_timestamp(tokens, time)
        assert len(time) == len(words), (tokens, words)

        hyps.append(words)
        timestamps.append(time)

    return hyps, timestamps
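

# A typical call site is sketched below (argument values are assumptions, not
# taken from a specific recipe):
#
#   res = get_texts_with_timestamp(best_paths)
#   hyps, timestamps = parse_hyp_and_timestamp(
#       res=res,
#       decoding_method="fast_beam_search",
#       sp=sp,
#       subsampling_factor=4,
#       frame_shift_ms=10,
#   )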


# `is_module_available` is copied from
# https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9
def is_module_available(*modules: str) -> bool:
    r"""Returns if a top-level module with :attr:`name` exists *without*
    importing it. This is generally safer than try-catch block around a
    `import X`.

    Note: "borrowed" from torchaudio.
    """
    import importlib

    return all(importlib.util.find_spec(m) is not None for m in modules)
@ -23,4 +23,4 @@ multi_quantization

onnx
onnxruntime
onnx_graphsurgeon -i https://pypi.ngc.nvidia.com
kaldifst
@ -3,8 +3,4 @@ kaldialign
sentencepiece>=0.1.96
tensorboard
typeguard
multi_quantization
onnx
onnxruntime
--extra-index-url https://pypi.ngc.nvidia.com
dill

74 test/test_ngram_lm.py Executable file
@ -0,0 +1,74 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import graphviz

from icefall import is_module_available

if not is_module_available("kaldifst"):
    raise ValueError("Please 'pip install kaldifst' first.")

import kaldifst

from icefall import NgramLm, NgramLmStateCost


def generate_fst(filename: str):
    s = """
3 5 1 1 3.00464
3 0 3 0 5.75646
0 1 1 1 12.0533
0 2 2 2 7.95954
0 9.97787
1 4 2 2 3.35436
1 0 3 0 7.59853
2 0 3 0
4 2 3 0 7.43735
4 0.551239
5 4 2 2 0.804938
5 1 3 0 9.67086
"""
    fst = kaldifst.compile(s=s, acceptor=False)
    fst.write(filename)
    fst_dot = kaldifst.draw(fst, acceptor=False, portrait=True)
    source = graphviz.Source(fst_dot)
    source.render(outfile=f"{filename}.svg")


def main():
    filename = "test.fst"
    generate_fst(filename)
    ngram_lm = NgramLm(filename, backoff_id=3, is_binary=True)
    for label in [1, 2, 3, 4, 5]:
        print("---label---", label)
        p = ngram_lm.get_next_state_and_cost(state=5, label=label)
        print(p)
        print("---")

    state_cost = NgramLmStateCost(ngram_lm)
    s0 = state_cost.forward_one_step(1)
    print(s0.state_cost)

    s1 = s0.forward_one_step(2)
    print(s1.state_cost)

    s2 = s1.forward_one_step(2)
    print(s2.state_cost)


if __name__ == "__main__":
    main()