Merge branch 'master' into streaming5

pkufool 2022-07-18 18:09:10 +08:00
commit d773b29db2
67 changed files with 9750 additions and 164 deletions


@ -9,7 +9,7 @@ per-file-ignores =
egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
egs/*/ASR/*/optim.py: E501,
egs/*/ASR/*/scaling.py: E501,
- egs/librispeech/ASR/conv_emformer_transducer_stateless/*.py: E501, E203
+ egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
# invalid escape sequence (cause by tex formular), W605
icefall/utils.py: E501, W605


@ -27,7 +27,7 @@ soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
- ln -s pretrained-epoch-29-avg-5-torch-1.10.pt pretrained.pt
+ ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt pretrained.pt
popd
for sym in 1 2 3; do


@ -37,7 +37,7 @@ for sym in 1 2 3; do
--nhead 8 \
--encoder-dim 512 \
--decoder-dim 512 \
- --joiner-dim 512
+ --joiner-dim 512 \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -82,6 +82,7 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
./pruned_transducer_stateless5/decode.py \
--decoding-method $method \
+ --use-averaged-model 0 \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \


@ -0,0 +1,19 @@
# Introduction
This recipe contains various ASR models trained with AISHELL-2.
[./RESULTS.md](./RESULTS.md) contains the latest results.
# Transducers
There are various folders in this directory whose names contain `transducer`.
The following table lists the differences among them.
| | Encoder | Decoder | Comment |
|---------------------------------------|---------------------|--------------------|-----------------------------|
| `pruned_transducer_stateless5` | Conformer (modified) | Embedding + Conv1d | same as `pruned_transducer_stateless5` in the librispeech recipe |
The decoder in `transducer_stateless` is modified from the paper
[RNN-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.
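As a rough illustration of that design (a sketch only; the vocabulary size, embedding dimension, and layer configuration below are placeholders, not this recipe's exact values), the stateless decoder boils down to an embedding followed by a 1-D convolution over the last few emitted symbols:

```python
# Sketch of an "Embedding + Conv1d" stateless decoder.
# All dimensions below are illustrative placeholders.
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, context_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # A depthwise Conv1d over the previous `context_size` symbols.
        self.conv = nn.Conv1d(
            embed_dim,
            embed_dim,
            kernel_size=context_size,
            groups=embed_dim,
            bias=False,
        )

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, context_size), the previously emitted symbols.
        emb = self.embedding(y).permute(0, 2, 1)  # (N, embed_dim, context_size)
        return self.conv(emb).permute(0, 2, 1)    # (N, 1, embed_dim)


decoder = StatelessDecoder(vocab_size=5000, embed_dim=512, context_size=2)
out = decoder(torch.tensor([[3, 7]]))  # shape: (1, 1, 512)
```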


@ -0,0 +1,89 @@
## Results
### Aishell2 char-based training results (Pruned Transducer 5)
#### 2022-07-11
Using the code from this PR: https://github.com/k2-fsa/icefall/pull/465.
When trained with a context size of 1, the WERs are:
| | dev-ios | test-ios | comment |
|------------------------------------|-------|----------|----------------------------------|
| greedy search | 5.57 | 5.89 | --epoch 25, --avg 5, --max-duration 600 |
| modified beam search (beam size 4) | 5.32 | 5.56 | --epoch 25, --avg 5, --max-duration 600 |
| fast beam search (set as default) | 5.5 | 5.78 | --epoch 25, --avg 5, --max-duration 600 |
| fast beam search nbest | 5.46 | 5.74 | --epoch 25, --avg 5, --max-duration 600 |
| fast beam search oracle | 1.92 | 2.2 | --epoch 25, --avg 5, --max-duration 600 |
| fast beam search nbest LG | 5.59 | 5.93 | --epoch 25, --avg 5, --max-duration 600 |
The training command to reproduce the results is given below:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless5/train.py \
--world-size 4 \
--lang-dir data/lang_char \
--num-epochs 40 \
--start-epoch 1 \
--exp-dir /result \
--max-duration 300 \
--use-fp16 0 \
--num-encoder-layers 24 \
--dim-feedforward 1536 \
--nhead 8 \
--encoder-dim 384 \
--decoder-dim 512 \
--joiner-dim 512 \
--context-size 1
```
The decoding command is:
```bash
for method in greedy_search modified_beam_search \
fast_beam_search fast_beam_search_nbest \
fast_beam_search_nbest_oracle fast_beam_search_nbest_LG; do
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method $method \
--max-sym-per-frame 1 \
--num-encoder-layers 24 \
--dim-feedforward 1536 \
--nhead 8 \
--encoder-dim 384 \
--decoder-dim 512 \
--joiner-dim 512 \
--context-size 1 \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5 \
--use-averaged-model True
done
```
The tensorboard training log can be found at
https://tensorboard.dev/experiment/RXyX4QjQQVKjBS2eQ2Qajg/#scalars
A pre-trained model and decoding logs can be found at <https://huggingface.co/yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-B-2022-07-12>
When trained with a context size of 2, the WERs are:
| | dev-ios | test-ios | comment |
|------------------------------------|-------|----------|----------------------------------|
| greedy search | 5.47 | 5.81 | --epoch 25, --avg 5, --max-duration 600 |
| modified beam search (beam size 4) | 5.38 | 5.61 | --epoch 25, --avg 5, --max-duration 600 |
| fast beam search (set as default) | 5.36 | 5.61 | --epoch 25, --avg 5, --max-duration 600 |
| fast beam search nbest | 5.37 | 5.6 | --epoch 25, --avg 5, --max-duration 600 |
| fast beam search oracle | 2.04 | 2.2 | --epoch 25, --avg 5, --max-duration 600 |
| fast beam search nbest LG | 5.59 | 5.82 | --epoch 25, --avg 5, --max-duration 600 |
The tensorboard training log can be found at
https://tensorboard.dev/experiment/5AxJ8LHoSre8kDAuLp4L7Q/#scalars
A pre-trained model and decoding logs can be found at <https://huggingface.co/yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-A-2022-07-12>
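To experiment with one of the released checkpoints locally, something like the following works (a sketch; the in-repo filename `exp/pretrained.pt` is an assumption, so check the actual repository layout first):

```python
# Sketch for fetching and inspecting a released checkpoint.
# The filename "exp/pretrained.pt" is an assumption about the repo layout.
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-A-2022-07-12",
    filename="exp/pretrained.pt",
)
state = torch.load(ckpt_path, map_location="cpu")
print(state.keys())  # export.py saves {"model": model.state_dict()}
```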


@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_lg.py


@ -0,0 +1,114 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the aishell2 dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_aishell2(num_mel_bins: int = 80):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
dataset_parts = (
"train",
"dev",
"test",
)
prefix = "aishell2"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
return parser.parse_args()
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_aishell2(num_mel_bins=args.num_mel_bins)
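A quick sanity check after the script finishes is to open one of the generated cut manifests and load a feature matrix. A sketch, assuming the default `data/fbank` layout written above:

```python
# Sketch for sanity-checking the output of compute_fbank_aishell2.py.
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/aishell2_cuts_dev.jsonl.gz")
cut = next(iter(cuts))
feats = cut.load_features()  # numpy array of shape (num_frames, num_mel_bins)
print(cut.id, cut.duration, feats.shape)
```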


@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py


@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed values to choose the minimum/maximum duration
for removing short and long utterances during training.
See the function `remove_short_and_long_utt()` in transducer_stateless/train.py
for usage.
"""
from lhotse import load_manifest_lazy
def main():
paths = [
"./data/fbank/aishell2_cuts_train.jsonl.gz",
"./data/fbank/aishell2_cuts_dev.jsonl.gz",
"./data/fbank/aishell2_cuts_test.jsonl.gz",
]
for path in paths:
print(f"Starting display the statistics for {path}")
cuts = load_manifest_lazy(path)
cuts.describe()
if __name__ == "__main__":
main()
"""
Starting display the statistics for ./data/fbank/aishell2_cuts_train.jsonl.gz
Cuts count: 3026106
Total duration (hours): 3021.2
Speech duration (hours): 3021.2 (100.0%)
***
Duration statistics (seconds):
mean 3.6
std 1.5
min 0.3
25% 2.4
50% 3.3
75% 4.4
99% 8.2
99.5% 8.9
99.9% 10.6
max 21.5
Starting display the statistics for ./data/fbank/aishell2_cuts_dev.jsonl.gz
Cuts count: 2500
Total duration (hours): 2.0
Speech duration (hours): 2.0 (100.0%)
***
Duration statistics (seconds):
mean 2.9
std 1.0
min 1.1
25% 2.2
50% 2.7
75% 3.4
99% 6.3
99.5% 6.7
99.9% 7.8
max 9.4
Starting display the statistics for ./data/fbank/aishell2_cuts_test.jsonl.gz
Cuts count: 5000
Total duration (hours): 4.0
Speech duration (hours): 4.0 (100.0%)
***
Duration statistics (seconds):
mean 2.9
std 1.0
min 1.1
25% 2.2
50% 2.7
75% 3.3
99% 6.2
99.5% 6.6
99.9% 7.7
max 8.5
"""


@ -0,0 +1 @@
../../../aidatatang_200zh/ASR/local/prepare_char.py


@ -0,0 +1 @@
../../../wenetspeech/ASR/local/prepare_lang.py


@ -0,0 +1 @@
../../../wenetspeech/ASR/local/prepare_words.py


@ -0,0 +1 @@
../../../wenetspeech/ASR/local/text2segments.py


@ -0,0 +1 @@
../../../aidatatang_200zh/ASR/local/text2token.py

egs/aishell2/ASR/prepare.sh (new executable file, 181 lines)

@ -0,0 +1,181 @@
#!/usr/bin/env bash
set -eou pipefail
nj=30
stage=0
stop_stage=5
# We assume dl_dir (download dir) contains the following
# directories and files. If not, you need to apply for AISHELL-2 through
# their official website.
# https://www.aishelltech.com/aishell_2
#
# - $dl_dir/aishell2
#
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: Download data"
# If you have pre-downloaded it to /path/to/aishell2,
# you can create a symlink
#
# ln -sfv /path/to/aishell2 $dl_dir/aishell2
#
# The directory structure is
# aishell2/
# |-- AISHELL-2
# | |-- iOS
# |-- data
# |-- wav
# |-- trans.txt
# |-- dev
# |-- wav
# |-- trans.txt
# |-- test
# |-- wav
# |-- trans.txt
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/musan
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare aishell2 manifest"
# We assume that you have downloaded and unzipped the aishell2 corpus
# to $dl_dir/aishell2
if [ ! -f data/manifests/.aishell2_manifests.done ]; then
mkdir -p data/manifests
lhotse prepare aishell2 $dl_dir/aishell2 data/manifests -j $nj
touch data/manifests/.aishell2_manifests.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to $dl_dir/musan
if [ ! -f data/manifests/.musan_manifests.done ]; then
log "It may take 6 minutes"
mkdir -p data/manifests
lhotse prepare musan $dl_dir/musan data/manifests
touch data/manifests/.musan_manifests.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for aishell2"
if [ ! -f data/fbank/.aishell2.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell2.py
touch data/fbank/.aishell2.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
if [ ! -f data/fbank/.musan.done ]; then
mkdir -p data/fbank
./local/compute_fbank_musan.py
touch data/fbank/.musan.done
fi
fi
lang_char_dir=data/lang_char
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare char based lang"
mkdir -p $lang_char_dir
# Prepare text.
# Note: in Linux, you can install jq with the following command:
# 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
# 2. chmod +x ./jq
# 3. cp jq /usr/bin
if [ ! -f $lang_char_dir/text ]; then
gunzip -c data/manifests/aishell2_supervisions_train.jsonl.gz \
| jq '.text' | sed 's/"//g' \
| ./local/text2token.py -t "char" > $lang_char_dir/text
fi
# Chinese word segmentation of the text; it takes about 15 minutes.
# If you can't install paddle-tiny with Python 3.8, please refer to
# https://github.com/fxsjy/jieba/issues/920
if [ ! -f $lang_char_dir/text_words_segmentation ]; then
python3 ./local/text2segments.py \
--input-file $lang_char_dir/text \
--output-file $lang_char_dir/text_words_segmentation
fi
cat $lang_char_dir/text_words_segmentation | sed 's/ /\n/g' \
| sort -u | sed '/^$/d' | uniq > $lang_char_dir/words_no_ids.txt
if [ ! -f $lang_char_dir/words.txt ]; then
python3 ./local/prepare_words.py \
--input-file $lang_char_dir/words_no_ids.txt \
--output-file $lang_char_dir/words.txt
fi
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
python3 ./local/prepare_char.py
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare G"
# We assume you have installed kaldilm; if not, please install
# it using: pip install kaldilm
if [ ! -f ${lang_char_dir}/3-gram.unpruned.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order 3 \
-text $lang_char_dir/text_words_segmentation \
-lm $lang_char_dir/3-gram.unpruned.arpa
fi
mkdir -p data/lm
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
# It is used in building LG
python3 -m kaldilm \
--read-symbol-table="$lang_char_dir/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
$lang_char_dir/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Compile LG"
./local/compile_lg.py --lang-dir $lang_char_dir
fi
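If `jq` is not available, the text extraction in stage 5 can also be done with a few lines of Python. A sketch, assuming the lhotse supervisions manifest is a gzipped JSONL file whose records carry a `text` field (which is what the pipeline above relies on); the output path is a placeholder, and the result should still be passed through `./local/text2token.py` as in stage 5:

```python
# Sketch of the stage 5 text extraction without jq.
import gzip
import json

in_path = "data/manifests/aishell2_supervisions_train.jsonl.gz"
out_path = "data/lang_char/text.raw"  # placeholder name

with gzip.open(in_path, "rt", encoding="utf-8") as fin, \
        open(out_path, "w", encoding="utf-8") as fout:
    for line in fin:
        sup = json.loads(line)
        fout.write(sup["text"] + "\n")
```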


@ -0,0 +1,418 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class AiShell2AsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. ios, android, mic).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to gen cuts from aishell2_cuts_train.jsonl.gz")
return load_manifest_lazy(
self.args.manifest_dir / "aishell2_cuts_train.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz")
return load_manifest_lazy(
self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to gen cuts from aishell2_cuts_test.jsonl.gz")
return load_manifest_lazy(
self.args.manifest_dir / "aishell2_cuts_test.jsonl.gz"
)
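Taken together, a sketch of how this data module is typically driven from a script, mirroring the pattern used by the decode.py added later in this change (argument values come from the command line):

```python
# Sketch: wiring AiShell2AsrDataModule into a script.
import argparse

from asr_datamodule import AiShell2AsrDataModule

parser = argparse.ArgumentParser()
AiShell2AsrDataModule.add_arguments(parser)
args = parser.parse_args()

aishell2 = AiShell2AsrDataModule(args)
train_cuts = aishell2.train_cuts()
train_dl = aishell2.train_dataloaders(train_cuts)

for batch in train_dl:
    feature = batch["inputs"]  # (N, T, C) fbank features
    texts = batch["supervisions"]["text"]
    break
```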


@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py


@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py


@ -0,0 +1,791 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
(5) fast beam search (nbest)
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import AiShell2AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_char",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
sentence = "".join([lexicon.word_table[i] for i in hyp])
hyps.append(list(sentence))
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[idx] for idx in hyp])
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
lexicon=lexicon,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
this_batch.append((ref_text, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AiShell2AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.unk_id = lexicon.token_table["<unk>"]
params.vocab_size = max(lexicon.tokens) + 1
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
aishell2 = AiShell2AsrDataModule(args)
valid_cuts = aishell2.valid_cuts()
test_cuts = aishell2.test_cuts()
# use ios sets for dev and test
dev_dl = aishell2.valid_dataloaders(valid_cuts)
test_dl = aishell2.test_dataloaders(test_cuts)
test_sets = ["dev", "test"]
test_dls = [dev_dl, test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
lexicon=lexicon,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()
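A note on `--use-averaged-model` above: it only needs the checkpoints at `epoch-avg` and `epoch` because each checkpoint stores a cumulative running average of the model weights, and the average over the epoch range can be recovered by differencing. A sketch of the arithmetic only, not the actual icefall implementation:

```python
# If avg_k is the running average of the weights over steps 1..k, the
# average over the range (start, end] follows from differencing.
def range_average(avg_start, avg_end, start: int, end: int):
    return (avg_end * end - avg_start * start) / (end - start)


# With scalar "weights" for illustration:
print(range_average(avg_start=2.0, avg_end=3.0, start=20, end=25))  # -> 7.0
```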


@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py


@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py


@ -0,0 +1,274 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--epoch 25 \
--avg 5
It will generate a file exp_dir/pretrained.pt
To use the generated file with `pruned_transducer_stateless5/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/aishell2/ASR
./pruned_transducer_stateless5/decode.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--lang-dir data/lang_char
"""
import argparse
import logging
from pathlib import Path
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="The lang dir",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.unk_id = lexicon.token_table["<unk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
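Once `cpu_jit.pt` has been exported with `--jit 1`, it can be loaded without the Python model definitions. A minimal sketch (since `forward()` is ignored during scripting, downstream code calls the encoder/decoder/joiner submodules directly):

```python
# Sketch for loading the torch.jit.script()-ed model saved by export.py.
import torch

model = torch.jit.load("pruned_transducer_stateless5/exp/cpu_jit.pt")
model.eval()
print(model.encoder)  # the scripted encoder submodule
```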


@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py


@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/model.py


@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/optim.py


@ -0,0 +1,342 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
./pruned_transducer_stateless5/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""Path to lang.
""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.unk_id = lexicon.token_table["<unk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append([lexicon.token_table[idx] for idx in hyp])
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = "".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py

File diff suppressed because it is too large

egs/aishell2/ASR/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared/

View File

@ -23,8 +23,8 @@ The following table lists the differences among them.
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner| | `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner|
| `pruned_transducer_stateless6` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + distillation with hubert| | `pruned_transducer_stateless6` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + distillation with hubert|
| `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR| | `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR|
| `conv_emformer_transducer_stateless` | Emformer | Embedding + Conv1d | Using Emformer augmented with convolution for streaming ASR + mechanisms in reworked model | | `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model |
| `conv_emformer_transducer_stateless2` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer with simplified memory for streaming ASR + mechanisms in reworked model |
The decoder in `transducer_stateless` is modified from the paper The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

View File

@ -1,5 +1,317 @@
## Results ## Results
### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T 2)
[conv_emformer_transducer_stateless2](./conv_emformer_transducer_stateless2)
It implements [Emformer](https://arxiv.org/abs/2010.10759) augmented with a convolution module and a simplified memory bank for streaming ASR.
It is modified from [torchaudio](https://github.com/pytorch/audio).
See <https://github.com/k2-fsa/icefall/pull/440> for more details.
#### With lower latency setup, training on full librispeech
In this model, the lengths of chunk and right context are 32 frames (i.e., 0.32s) and 8 frames (i.e., 0.08s), respectively.
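(For orientation: assuming the usual 10 ms frame shift of the Fbank features, the frame counts translate to seconds as sketched below.)

```python
# Hypothetical sanity check, assuming a 10 ms feature frame shift.
frame_shift = 0.01  # seconds per frame
print(32 * frame_shift, 8 * frame_shift)  # -> 0.32 0.08 (chunk, right context)
```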
The WERs are:
| | test-clean | test-other | comment | decoding mode |
|-------------------------------------|------------|------------|----------------------|----------------------|
| greedy search (max sym per frame 1) | 3.5 | 9.09 | --epoch 30 --avg 10 | simulated streaming |
| greedy search (max sym per frame 1) | 3.57 | 9.1 | --epoch 30 --avg 10 | streaming |
| fast beam search | 3.5 | 8.91 | --epoch 30 --avg 10 | simulated streaming |
| fast beam search | 3.54 | 8.91 | --epoch 30 --avg 10 | streaming |
| modified beam search | 3.43 | 8.86 | --epoch 30 --avg 10 | simulated streaming |
| modified beam search | 3.48 | 8.88 | --epoch 30 --avg 10 | streaming |
The training command is:
```bash
./conv_emformer_transducer_stateless2/train.py \
--world-size 6 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--full-libri 1 \
--max-duration 280 \
--master-port 12321 \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32
```
The tensorboard log can be found at
<https://tensorboard.dev/experiment/W5MpxekiQLSPyM4fe5hbKg/>
The simulated streaming decoding command using greedy search is:
```bash
./conv_emformer_transducer_stateless2/decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--max-duration 300 \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32 \
--decoding-method greedy_search \
--use-averaged-model True
```
The simulated streaming decoding command using fast beam search is:
```bash
./conv_emformer_transducer_stateless2/decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--max-duration 300 \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32 \
--decoding-method fast_beam_search \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8
```
The simulated streaming decoding command using modified beam search is:
```bash
./conv_emformer_transducer_stateless2/decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--max-duration 300 \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32 \
--decoding-method modified_beam_search \
--use-averaged-model True \
--beam-size 4
```
The streaming decoding command using greedy search is:
```bash
./conv_emformer_transducer_stateless2/streaming_decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32 \
--decoding-method greedy_search \
--use-averaged-model True
```
The streaming decoding command using fast beam search is:
```bash
./conv_emformer_transducer_stateless2/streaming_decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32 \
--decoding-method fast_beam_search \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8
```
The streaming decoding command using modified beam search is:
```bash
./conv_emformer_transducer_stateless2/streaming_decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32 \
--decoding-method modified_beam_search \
--use-averaged-model True \
--beam-size 4
```
Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05>
#### With higher latency setup, training on full librispeech
In this model, the lengths of chunk and right context are 64 frames (i.e., 0.64s) and 16 frames (i.e., 0.16s), respectively.
The WERs are:
| | test-clean | test-other | comment | decoding mode |
|-------------------------------------|------------|------------|----------------------|----------------------|
| greedy search (max sym per frame 1) | 3.3 | 8.71 | --epoch 30 --avg 10 | simulated streaming |
| greedy search (max sym per frame 1) | 3.35 | 8.65 | --epoch 30 --avg 10 | streaming |
| fast beam search | 3.27 | 8.58 | --epoch 30 --avg 10 | simulated streaming |
| fast beam search | 3.31 | 8.48 | --epoch 30 --avg 10 | streaming |
| modified beam search | 3.26 | 8.56 | --epoch 30 --avg 10 | simulated streaming |
| modified beam search | 3.29 | 8.47 | --epoch 30 --avg 10 | streaming |
The training command is:
```bash
./conv_emformer_transducer_stateless2/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--full-libri 1 \
--max-duration 280 \
--master-port 12321 \
--num-encoder-layers 12 \
--chunk-length 64 \
--cnn-module-kernel 31 \
--left-context-length 64 \
--right-context-length 16 \
--memory-size 32
```
The tensorboard log can be found at
<https://tensorboard.dev/experiment/eRx6XwbOQhGlywgD8lWBjw/>
The simulated streaming decoding command using greedy search is:
```bash
./conv_emformer_transducer_stateless2/decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--max-duration 300 \
--num-encoder-layers 12 \
--chunk-length 64 \
--cnn-module-kernel 31 \
--left-context-length 64 \
--right-context-length 16 \
--memory-size 32 \
--decoding-method greedy_search \
--use-averaged-model True
```
The simulated streaming decoding command using fast beam search is:
```bash
./conv_emformer_transducer_stateless2/decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--max-duration 300 \
--num-encoder-layers 12 \
--chunk-length 64 \
--cnn-module-kernel 31 \
--left-context-length 64 \
--right-context-length 16 \
--memory-size 32 \
--decoding-method fast_beam_search \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8
```
The simulated streaming decoding command using modified beam search is:
```bash
./conv_emformer_transducer_stateless2/decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--max-duration 300 \
--num-encoder-layers 12 \
--chunk-length 64 \
--cnn-module-kernel 31 \
--left-context-length 64 \
--right-context-length 16 \
--memory-size 32 \
--decoding-method modified_beam_search \
--use-averaged-model True \
--beam-size 4
```
The streaming decoding command using greedy search is:
```bash
./conv_emformer_transducer_stateless2/streaming_decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--chunk-length 64 \
--cnn-module-kernel 31 \
--left-context-length 64 \
--right-context-length 16 \
--memory-size 32 \
--decoding-method greedy_search \
--use-averaged-model True
```
The streaming decoding command using fast beam search is:
```bash
./conv_emformer_transducer_stateless2/streaming_decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--chunk-length 64 \
--cnn-module-kernel 31 \
--left-context-length 64 \
--right-context-length 16 \
--memory-size 32 \
--decoding-method fast_beam_search \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8
```
The streaming decoding command using modified beam search is:
```bash
./conv_emformer_transducer_stateless2/streaming_decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--chunk-length 64 \
--cnn-module-kernel 31 \
--left-context-length 64 \
--right-context-length 16 \
--memory-size 32 \
--decoding-method modified_beam_search \
--use-averaged-model True \
--beam-size 4
```
Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-larger-latency-2022-07-06>
### LibriSpeech BPE training results (Pruned Stateless Streaming Conformer RNN-T) ### LibriSpeech BPE training results (Pruned Stateless Streaming Conformer RNN-T)
#### [pruned_transducer_stateless](./pruned_transducer_stateless) #### [pruned_transducer_stateless](./pruned_transducer_stateless)
@ -556,9 +868,9 @@ Number of model parameters 118129516 (i.e, 118.13 M).
| | test-clean | test-other | comment | | | test-clean | test-other | comment |
|-------------------------------------|------------|------------|----------------------------------------| |-------------------------------------|------------|------------|----------------------------------------|
| greedy search (max sym per frame 1) | 2.39 | 5.57 | --epoch 39 --avg 7 --max-duration 600 | | greedy search (max sym per frame 1) | 2.43 | 5.72 | --epoch 30 --avg 10 --max-duration 600 |
| modified beam search | 2.35 | 5.50 | --epoch 39 --avg 7 --max-duration 600 | | modified beam search | 2.43 | 5.69 | --epoch 30 --avg 10 --max-duration 600 |
| fast beam search | 2.38 | 5.50 | --epoch 39 --avg 7 --max-duration 600 | | fast beam search | 2.43 | 5.67 | --epoch 30 --avg 10 --max-duration 600 |
The training commands are: The training commands are:
@ -567,8 +879,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./pruned_transducer_stateless5/train.py \ ./pruned_transducer_stateless5/train.py \
--world-size 8 \ --world-size 8 \
--num-epochs 40 \ --num-epochs 30 \
--start-epoch 0 \ --start-epoch 1 \
--full-libri 1 \ --full-libri 1 \
--exp-dir pruned_transducer_stateless5/exp-L \ --exp-dir pruned_transducer_stateless5/exp-L \
--max-duration 300 \ --max-duration 300 \
@ -582,15 +894,15 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
``` ```
The tensorboard log can be found at The tensorboard log can be found at
<https://tensorboard.dev/experiment/Zq0h3KpnQ2igWbeR4U82Pw/> <https://tensorboard.dev/experiment/aWzDj5swSE2VmcOYgoe3vQ>
The decoding commands are: The decoding commands are:
```bash ```bash
for method in greedy_search modified_beam_search fast_beam_search; do for method in greedy_search modified_beam_search fast_beam_search; do
./pruned_transducer_stateless5/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 39 \ --epoch 30 \
--avg 7 \ --avg 10 \
--exp-dir ./pruned_transducer_stateless5/exp-L \ --exp-dir ./pruned_transducer_stateless5/exp-L \
--max-duration 600 \ --max-duration 600 \
--decoding-method $method \ --decoding-method $method \
@ -600,13 +912,14 @@ for method in greedy_search modified_beam_search fast_beam_search; do
--nhead 8 \ --nhead 8 \
--encoder-dim 512 \ --encoder-dim 512 \
--decoder-dim 512 \ --decoder-dim 512 \
--joiner-dim 512 --joiner-dim 512 \
--use-averaged-model True
done done
``` ```
You can find a pretrained model, training logs, decoding logs, and decoding You can find a pretrained model, training logs, decoding logs, and decoding
results at: results at:
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13> <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07>
#### Medium #### Medium
@ -615,9 +928,9 @@ Number of model parameters 30896748 (i.e, 30.9 M).
| | test-clean | test-other | comment | | | test-clean | test-other | comment |
|-------------------------------------|------------|------------|-----------------------------------------| |-------------------------------------|------------|------------|-----------------------------------------|
| greedy search (max sym per frame 1) | 2.88 | 6.69 | --epoch 39 --avg 17 --max-duration 600 | | greedy search (max sym per frame 1) | 2.87 | 6.92 | --epoch 30 --avg 10 --max-duration 600 |
| modified beam search | 2.83 | 6.59 | --epoch 39 --avg 17 --max-duration 600 | | modified beam search | 2.83 | 6.75 | --epoch 30 --avg 10 --max-duration 600 |
| fast beam search | 2.83 | 6.61 | --epoch 39 --avg 17 --max-duration 600 | | fast beam search | 2.81 | 6.76 | --epoch 30 --avg 10 --max-duration 600 |
The training commands are: The training commands are:
@ -626,8 +939,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./pruned_transducer_stateless5/train.py \ ./pruned_transducer_stateless5/train.py \
--world-size 8 \ --world-size 8 \
--num-epochs 40 \ --num-epochs 30 \
--start-epoch 0 \ --start-epoch 1 \
--full-libri 1 \ --full-libri 1 \
--exp-dir pruned_transducer_stateless5/exp-M \ --exp-dir pruned_transducer_stateless5/exp-M \
--max-duration 300 \ --max-duration 300 \
@ -641,15 +954,15 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
``` ```
The tensorboard log can be found at The tensorboard log can be found at
<https://tensorboard.dev/experiment/bOQvULPsQ1iL7xpdI0VbXw/> <https://tensorboard.dev/experiment/04xtWUKPRmebSnpzN1GMHQ>
The decoding commands are: The decoding commands are:
```bash ```bash
for method in greedy_search modified_beam_search fast_beam_search; do for method in greedy_search modified_beam_search fast_beam_search; do
./pruned_transducer_stateless5/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 39 \ --epoch 30 \
--avg 17 \ --avg 10 \
--exp-dir ./pruned_transducer_stateless5/exp-M \ --exp-dir ./pruned_transducer_stateless5/exp-M \
--max-duration 600 \ --max-duration 600 \
--decoding-method $method \ --decoding-method $method \
@ -659,13 +972,14 @@ for method in greedy_search modified_beam_search fast_beam_search; do
--nhead 4 \ --nhead 4 \
--encoder-dim 256 \ --encoder-dim 256 \
--decoder-dim 512 \ --decoder-dim 512 \
--joiner-dim 512 --joiner-dim 512 \
--use-averaged-model True
done done
``` ```
You can find a pretrained model, training logs, decoding logs, and decoding You can find a pretrained model, training logs, decoding logs, and decoding
results at: results at:
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-M-2022-05-13> <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-M-2022-07-07>
#### Baseline-2 #### Baseline-2
@ -675,19 +989,19 @@ layers (24 v.s 12) but a narrower model (1536 feedforward dim and 384 encoder di
| | test-clean | test-other | comment | | | test-clean | test-other | comment |
|-------------------------------------|------------|------------|-----------------------------------------| |-------------------------------------|------------|------------|-----------------------------------------|
| greedy search (max sym per frame 1) | 2.41 | 5.70 | --epoch 31 --avg 17 --max-duration 600 | | greedy search (max sym per frame 1) | 2.54 | 5.72 | --epoch 30 --avg 10 --max-duration 600 |
| modified beam search | 2.41 | 5.69 | --epoch 31 --avg 17 --max-duration 600 | | modified beam search | 2.47 | 5.71 | --epoch 30 --avg 10 --max-duration 600 |
| fast beam search | 2.41 | 5.69 | --epoch 31 --avg 17 --max-duration 600 | | fast beam search | 2.5 | 5.72 | --epoch 30 --avg 10 --max-duration 600 |
```bash ```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./pruned_transducer_stateless5/train.py \ ./pruned_transducer_stateless5/train.py \
--world-size 8 \ --world-size 8 \
--num-epochs 40 \ --num-epochs 30 \
--start-epoch 0 \ --start-epoch 1 \
--full-libri 1 \ --full-libri 1 \
--exp-dir pruned_transducer_stateless5/exp \ --exp-dir pruned_transducer_stateless5/exp-B \
--max-duration 300 \ --max-duration 300 \
--use-fp16 0 \ --use-fp16 0 \
--num-encoder-layers 24 \ --num-encoder-layers 24 \
@ -699,19 +1013,16 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
``` ```
The tensorboard log can be found at The tensorboard log can be found at
<https://tensorboard.dev/experiment/73oY9U1mQiq0tbbcovZplw/> <https://tensorboard.dev/experiment/foVHNyqiRi2LhybmRUOAyg>
**Caution**: The training script is updated so that epochs are counted from 1
after the training.
The decoding commands are: The decoding commands are:
```bash ```bash
for method in greedy_search modified_beam_search fast_beam_search; do for method in greedy_search modified_beam_search fast_beam_search; do
./pruned_transducer_stateless5/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 31 \ --epoch 30 \
--avg 17 \ --avg 10 \
--exp-dir ./pruned_transducer_stateless5/exp-M \ --exp-dir ./pruned_transducer_stateless5/exp-B \
--max-duration 600 \ --max-duration 600 \
--decoding-method $method \ --decoding-method $method \
--max-sym-per-frame 1 \ --max-sym-per-frame 1 \
@ -720,13 +1031,14 @@ for method in greedy_search modified_beam_search fast_beam_search; do
--nhead 8 \ --nhead 8 \
--encoder-dim 384 \ --encoder-dim 384 \
--decoder-dim 512 \ --decoder-dim 512 \
--joiner-dim 512 --joiner-dim 512 \
--use-averaged-model True
done done
``` ```
You can find a pretrained model, training logs, decoding logs, and decoding You can find a pretrained model, training logs, decoding logs, and decoding
results at: results at:
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-narrower-2022-05-13> <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-B-2022-07-07>
### LibriSpeech BPE training results (Pruned Stateless Transducer 4) ### LibriSpeech BPE training results (Pruned Stateless Transducer 4)

View File

@ -277,10 +277,10 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
feature_lens += params.right_context_length feature_lens += params.chunk_length
feature = torch.nn.functional.pad( feature = torch.nn.functional.pad(
feature, feature,
pad=(0, 0, 0, params.right_context_length), pad=(0, 0, 0, params.chunk_length),
value=LOG_EPS, value=LOG_EPS,
) )

View File

@ -1141,8 +1141,8 @@ class EmformerEncoderLayer(nn.Module):
- output utterance, with shape (U, B, D); - output utterance, with shape (U, B, D);
- output right_context, with shape (R, B, D); - output right_context, with shape (R, B, D);
- output memory, with shape (1, B, D) or (0, B, D). - output memory, with shape (1, B, D) or (0, B, D).
- output state. - updated attention cache.
- updated conv_cache. - updated convolution cache.
""" """
R = right_context.size(0) R = right_context.size(0)
src = torch.cat([right_context, utterance]) src = torch.cat([right_context, utterance])
@ -1252,6 +1252,11 @@ class EmformerEncoder(nn.Module):
): ):
super().__init__() super().__init__()
assert (
chunk_length - 1
) & chunk_length == 0, "chunk_length should be a power of 2."
self.shift = int(math.log(chunk_length, 2))
self.use_memory = memory_size > 0 self.use_memory = memory_size > 0
self.init_memory_op = nn.AvgPool1d( self.init_memory_op = nn.AvgPool1d(
kernel_size=chunk_length, kernel_size=chunk_length,
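
The added assertion is the standard bit trick for checking that `chunk_length` is a power of two, and `self.shift` lets the floor division by `chunk_length` (used later in this diff for the memory mask) be written as a right shift. A small self-contained check of that equivalence, independent of the model code:

```python
# Right-shifting by log2(chunk_length) equals floor division when chunk_length
# is a power of two (illustration only, independent of the model code).
import math
import torch

chunk_length = 32
assert (chunk_length - 1) & chunk_length == 0  # power-of-two test
shift = int(math.log(chunk_length, 2))         # shift == 5

num_processed_frames = torch.arange(0, 200, 7)
assert torch.equal(
    num_processed_frames >> shift,
    torch.div(num_processed_frames, chunk_length, rounding_mode="floor"),
)
```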
@ -1525,7 +1530,6 @@ class EmformerEncoder(nn.Module):
right_context at the end. right_context at the end.
states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa
Cached states containing: Cached states containing:
- past_lens: number of past frames for each sample in batch
- attn_caches: attention states from preceding chunk's computation, - attn_caches: attention states from preceding chunk's computation,
where each element corresponds to each emformer layer where each element corresponds to each emformer layer
- conv_caches: left context for causal convolution, where each - conv_caches: left context for causal convolution, where each
@ -1580,13 +1584,15 @@ class EmformerEncoder(nn.Module):
# calculate padding mask to mask out initial zero caches # calculate padding mask to mask out initial zero caches
chunk_mask = make_pad_mask(output_lengths).to(x.device) chunk_mask = make_pad_mask(output_lengths).to(x.device)
memory_mask = ( memory_mask = (
torch.div( (
num_processed_frames, self.chunk_length, rounding_mode="floor" (num_processed_frames >> self.shift).view(x.size(1), 1)
).view(x.size(1), 1) <= torch.arange(self.memory_size, device=x.device).expand(
<= torch.arange(self.memory_size, device=x.device).expand( x.size(1), self.memory_size
x.size(1), self.memory_size )
) ).flip(1)
).flip(1) if self.use_memory
else torch.empty(0).to(dtype=torch.bool, device=x.device)
)
left_context_mask = ( left_context_mask = (
num_processed_frames.view(x.size(1), 1) num_processed_frames.view(x.size(1), 1)
<= torch.arange(self.left_context_length, device=x.device).expand( <= torch.arange(self.left_context_length, device=x.device).expand(
@ -1631,6 +1637,31 @@ class EmformerEncoder(nn.Module):
) )
return output, output_lengths, output_states return output, output_lengths, output_states
@torch.jit.export
def init_states(self, device: torch.device = torch.device("cpu")):
"""Create initial states."""
attn_caches = [
[
torch.zeros(self.memory_size, self.d_model, device=device),
torch.zeros(
self.left_context_length, self.d_model, device=device
),
torch.zeros(
self.left_context_length, self.d_model, device=device
),
]
for _ in range(self.num_encoder_layers)
]
conv_caches = [
torch.zeros(self.d_model, self.cnn_module_kernel - 1, device=device)
for _ in range(self.num_encoder_layers)
]
states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = (
attn_caches,
conv_caches,
)
return states
class Emformer(EncoderInterface): class Emformer(EncoderInterface):
def __init__( def __init__(
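
For reference, the cache layout produced by the new `init_states` can be summarized as below; this is only a sketch restating the shapes from the added code, assuming `encoder` is a constructed `EmformerEncoder` instance.

```python
# Illustrative shape check of the streaming caches created by init_states().
attn_caches, conv_caches = encoder.init_states(torch.device("cpu"))
assert len(attn_caches) == encoder.num_encoder_layers
# Per layer: one memory-bank cache and two left-context caches.
assert attn_caches[0][0].shape == (encoder.memory_size, encoder.d_model)
assert attn_caches[0][1].shape == (encoder.left_context_length, encoder.d_model)
assert attn_caches[0][2].shape == (encoder.left_context_length, encoder.d_model)
# Per layer: one causal-convolution cache.
assert conv_caches[0].shape == (encoder.d_model, encoder.cnn_module_kernel - 1)
```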
@ -1655,6 +1686,7 @@ class Emformer(EncoderInterface):
self.subsampling_factor = subsampling_factor self.subsampling_factor = subsampling_factor
self.right_context_length = right_context_length self.right_context_length = right_context_length
self.chunk_length = chunk_length
if subsampling_factor != 4: if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.") raise NotImplementedError("Support only 'subsampling_factor=4'.")
if chunk_length % subsampling_factor != 0: if chunk_length % subsampling_factor != 0:
@ -1803,6 +1835,11 @@ class Emformer(EncoderInterface):
return output, output_lengths, output_states return output, output_lengths, output_states
@torch.jit.export
def init_states(self, device: torch.device = torch.device("cpu")):
"""Create initial states."""
return self.encoder.init_states(device)
class Conv2dSubsampling(nn.Module): class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length). """Convolutional 2D subsampling (to 1/4 length).

View File

@ -43,15 +43,12 @@ class Stream(object):
device: device:
The device to run this stream. The device to run this stream.
""" """
self.device = device
self.LOG_EPS = LOG_EPS self.LOG_EPS = LOG_EPS
# Containing attention caches and convolution caches # Containing attention caches and convolution caches
self.states: Optional[ self.states: Optional[
Tuple[List[List[torch.Tensor]], List[torch.Tensor]] Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
] = None ] = None
# Initailize zero states.
self.init_states(params)
# It uses different attributes for different decoding methods. # It uses different attributes for different decoding methods.
self.context_size = params.context_size self.context_size = params.context_size
@ -107,34 +104,11 @@ class Stream(object):
def set_ground_truth(self, ground_truth: str) -> None: def set_ground_truth(self, ground_truth: str) -> None:
self.ground_truth = ground_truth self.ground_truth = ground_truth
def init_states(self, params: AttributeDict) -> None: def set_states(
attn_caches = [ self, states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
[ ) -> None:
torch.zeros( """Set states."""
params.memory_size, params.encoder_dim, device=self.device self.states = states
),
torch.zeros(
params.left_context_length // params.subsampling_factor,
params.encoder_dim,
device=self.device,
),
torch.zeros(
params.left_context_length // params.subsampling_factor,
params.encoder_dim,
device=self.device,
),
]
for _ in range(params.num_encoder_layers)
]
conv_caches = [
torch.zeros(
params.encoder_dim,
params.cnn_module_kernel - 1,
device=self.device,
)
for _ in range(params.num_encoder_layers)
]
self.states = (attn_caches, conv_caches)
def get_feature_chunk(self) -> torch.Tensor: def get_feature_chunk(self) -> torch.Tensor:
"""Get a chunk of feature frames. """Get a chunk of feature frames.

View File

@ -683,6 +683,8 @@ def decode_dataset(
LOG_EPS=LOG_EPSILON, LOG_EPS=LOG_EPSILON,
) )
stream.set_states(model.encoder.init_states(device))
audio: np.ndarray = cut.load_audio() audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples) # audio.shape: (1, num_samples)
assert len(audio.shape) == 2 assert len(audio.shape) == 2
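
Together with the `stream.py` change above, cache creation is now owned by the encoder and a stream merely stores the result. Roughly, using the names from these diffs (a sketch, not additional API):

```python
# Sketch only: the encoder creates the initial caches, the stream just keeps them.
states = model.encoder.init_states(device)  # (attn_caches, conv_caches)
stream.set_states(states)
```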

View File

@ -28,7 +28,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--start-epoch 1 \ --start-epoch 1 \
--exp-dir conv_emformer_transducer_stateless/exp \ --exp-dir conv_emformer_transducer_stateless/exp \
--full-libri 1 \ --full-libri 1 \
--max-duration 300 \ --max-duration 280 \
--master-port 12321 \ --master-port 12321 \
--num-encoder-layers 12 \ --num-encoder-layers 12 \
--chunk-length 32 \ --chunk-length 32 \

View File

@ -0,0 +1 @@
../conv_emformer_transducer_stateless/asr_datamodule.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1,657 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./conv_emformer_transducer_stateless2/decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--max-duration 300 \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32 \
--decoding-method greedy_search \
--use-averaged-model True
(2) modified beam search
./conv_emformer_transducer_stateless2/decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--max-duration 300 \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32 \
--decoding-method modified_beam_search \
--use-averaged-model True \
--beam-size 4
(3) fast beam search
./conv_emformer_transducer_stateless2/decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--max-duration 300 \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32 \
--decoding-method fast_beam_search \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless4/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
feature_lens += params.chunk_length
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.chunk_length),
value=LOG_EPS,
)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0
start = params.epoch - params.avg
assert start >= 1
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../conv_emformer_transducer_stateless/decoder.py

File diff suppressed because it is too large

View File

@ -0,0 +1 @@
../conv_emformer_transducer_stateless/encoder_interface.py

View File

@ -0,0 +1,287 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./conv_emformer_transducer_stateless2/export.py \
--exp-dir ./conv_emformer_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 10 \
--use-averaged-model=True \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32 \
--jit False
It will generate a file exp_dir/pretrained.pt
To use the generated file with `conv_emformer_transducer_stateless2/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./conv_emformer_transducer_stateless2/decode.py \
--exp-dir ./conv_emformer_transducer_stateless2/exp \
--epoch 9999 \
--avg 1 \
--max-duration 100 \
--bpe-model data/lang_bpe_500/bpe.model \
--use-averaged-model=False \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
add_model_arguments(parser)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
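
As a quick aside on the two artifacts written above: cpu_jit.pt is a self-contained TorchScript module, while pretrained.pt only holds a state_dict and therefore needs the recipe's model definition to be rebuilt first. A minimal loading sketch (the paths and the commented-out model setup here are illustrative, not part of this commit):

import torch

# TorchScript export: no Python model definition is needed.
scripted_model = torch.jit.load("conv_emformer_transducer_stateless2/exp/cpu_jit.pt")
scripted_model.eval()

# state_dict export: rebuild the model first, e.g. via get_transducer_model(params)
# from the recipe's train.py, then restore the weights.
checkpoint = torch.load(
    "conv_emformer_transducer_stateless2/exp/pretrained.pt", map_location="cpu"
)
# model.load_state_dict(checkpoint["model"])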

View File

@ -0,0 +1 @@
../conv_emformer_transducer_stateless/joiner.py

View File

@ -0,0 +1 @@
../conv_emformer_transducer_stateless/model.py

View File

@ -0,0 +1 @@
../conv_emformer_transducer_stateless/optim.py

View File

@ -0,0 +1 @@
../conv_emformer_transducer_stateless/scaling.py

View File

@ -0,0 +1 @@
../conv_emformer_transducer_stateless/stream.py

View File

@ -0,0 +1,980 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./conv_emformer_transducer_stateless2/streaming_decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32 \
--decoding-method greedy_search \
--use-averaged-model True
(2) modified beam search
./conv_emformer_transducer_stateless2/streaming_decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32 \
--decoding-method modified_beam_search \
--use-averaged-model True \
--beam-size 4
(3) fast beam search
./conv_emformer_transducer_stateless2/streaming_decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless2/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32 \
--decoding-method fast_beam_search \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
from lhotse import CutSet
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from emformer import LOG_EPSILON, stack_states, unstack_states
from kaldifeat import Fbank, FbankOptions
from stream import Stream
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import one_best_decoding
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_emformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--sampling-rate",
type=float,
default=16000,
help="Sample rate of the audio",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel",
)
add_model_arguments(parser)
return parser
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[Stream],
) -> None:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
streams:
A list of Stream objects.
"""
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
T = encoder_out.size(1)
encoder_out = model.joiner.encoder_proj(encoder_out)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (batch_size, 1, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits' shape: (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[Stream],
beam: int = 4,
):
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The RNN-T model.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
streams:
A list of stream objects.
beam:
Number of active paths during the beam search.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert len(streams) == encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size = len(streams)
T = encoder_out.size(1)
B = [stream.hyps for stream in streams]
encoder_out = model.joiner.encoder_proj(encoder_out)
for t in range(T):
current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.stack(
[hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)
logits = model.joiner(
current_encoder_out, decoder_out, project_input=False
)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
for i in range(batch_size):
streams[i].hyps = B[i]
def fast_beam_search_one_best(
model: nn.Module,
streams: List[Stream],
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
) -> None:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, and then
the shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
streams:
A list of stream objects.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
processed_lens:
A tensor of shape (N,) containing the number of processed frames
in `encoder_out` before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
"""
assert encoder_out.ndim == 3
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
B, T, C = encoder_out.shape
assert B == len(streams)
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(streams[i].rnnt_decoding_stream)
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
encoder_out = model.joiner.encoder_proj(encoder_out)
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
for i in range(B):
streams[i].hyp = hyps[i]
def decode_one_chunk(
model: nn.Module,
streams: List[Stream],
params: AttributeDict,
decoding_graph: Optional[k2.Fsa] = None,
) -> List[int]:
"""
Args:
model:
The Transducer model.
streams:
A list of Stream objects.
params:
It is returned by :func:`get_params`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search.
Returns:
A list of indexes indicating the finished streams.
"""
device = next(model.parameters()).device
feature_list = []
feature_len_list = []
state_list = []
num_processed_frames_list = []
for stream in streams:
# We should first get `stream.num_processed_frames`
# before calling `stream.get_feature_chunk()`
# since `stream.num_processed_frames` would be updated
num_processed_frames_list.append(stream.num_processed_frames)
feature = stream.get_feature_chunk()
feature_len = feature.size(0)
feature_list.append(feature)
feature_len_list.append(feature_len)
state_list.append(stream.states)
features = pad_sequence(
feature_list, batch_first=True, padding_value=LOG_EPSILON
).to(device)
feature_lens = torch.tensor(feature_len_list, device=device)
num_processed_frames = torch.tensor(
num_processed_frames_list, device=device
)
# Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa
tail_length = (
3 * params.subsampling_factor + params.right_context_length + 3
)
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPSILON,
)
# Stack states of all streams
states = stack_states(state_list)
encoder_out, encoder_out_lens, states = model.encoder.infer(
x=features,
x_lens=feature_lens,
states=states,
num_processed_frames=num_processed_frames,
)
if params.decoding_method == "greedy_search":
greedy_search(
model=model,
streams=streams,
encoder_out=encoder_out,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=streams,
encoder_out=encoder_out,
beam=params.beam_size,
)
elif params.decoding_method == "fast_beam_search":
# feature_len is needed to get partial results.
# The rnnt_decoding_stream for fast_beam_search.
fast_beam_search_one_best(
model=model,
streams=streams,
encoder_out=encoder_out,
processed_lens=(num_processed_frames >> 2) + encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
# Update cached states of each stream
state_list = unstack_states(states)
for i, s in enumerate(state_list):
streams[i].states = s
finished_streams = [i for i, stream in enumerate(streams) if stream.done]
return finished_streams
def create_streaming_feature_extractor() -> Fbank:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return Fbank(opts)
def decode_dataset(
cuts: CutSet,
model: nn.Module,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
):
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The Transducer model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = next(model.parameters()).device
log_interval = 300
fbank = create_streaming_feature_extractor()
decode_results = []
streams = []
for num, cut in enumerate(cuts):
# Each utterance has a Stream.
stream = Stream(
params=params,
decoding_graph=decoding_graph,
device=device,
LOG_EPS=LOG_EPSILON,
)
stream.set_states(model.encoder.init_states(device))
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1]"
samples = torch.from_numpy(audio).squeeze(0)
feature = fbank(samples)
stream.set_feature(feature)
stream.set_ground_truth(cut.supervisions[0].text)
streams.append(stream)
while len(streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
model=model,
streams=streams,
params=params,
decoding_graph=decoding_graph,
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
)
)
del streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
while len(streams) > 0:
finished_streams = decode_one_chunk(
model=model,
streams=streams,
params=params,
decoding_graph=decoding_graph,
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
)
)
del streams[i]
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
else:
key = f"beam_size_{params.beam_size}"
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=sorted(results))
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-length-{params.chunk_length}"
params.suffix += f"-left-context-length-{params.left_context_length}"
params.suffix += f"-right-context-length-{params.right_context_length}"
params.suffix += f"-memory-size-{params.memory_size}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-streaming-decode")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.device = device
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
model=model,
params=params,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
torch.manual_seed(20220410)
main()
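
One detail of decode_one_chunk above that is easy to miss is the tail padding: each chunk must keep at least one output frame after the 4x subsampling, the first-and-last-frame cutting, and the right-context cutting, hence tail_length = 3 * subsampling_factor + right_context_length + 3. A small self-contained sketch with the values used in the usage examples above (subsampling_factor=4, right_context_length=8) and LOG_EPSILON assumed to be log(1e-10):

import math
import torch

subsampling_factor = 4
right_context_length = 8
tail_length = 3 * subsampling_factor + right_context_length + 3  # 23 feature frames
log_epsilon = math.log(1e-10)  # assumed value of LOG_EPSILON

features = torch.randn(1, 15, 80)  # (batch, num_frames, num_features): too short
if features.size(1) < tail_length:
    pad_length = tail_length - features.size(1)
    # Pad the time axis on the right with the log-epsilon filler value.
    features = torch.nn.functional.pad(
        features, (0, 0, 0, pad_length), mode="constant", value=log_epsilon
    )
print(features.shape)  # torch.Size([1, 23, 80])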

View File

@ -0,0 +1,194 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from emformer import ConvolutionModule, Emformer, stack_states, unstack_states
def test_convolution_module_forward():
B, D = 2, 256
chunk_length = 4
right_context_length = 2
num_chunks = 3
U = num_chunks * chunk_length
R = num_chunks * right_context_length
kernel_size = 31
conv_module = ConvolutionModule(
chunk_length,
right_context_length,
D,
kernel_size,
)
utterance = torch.randn(U, B, D)
right_context = torch.randn(R, B, D)
utterance, right_context = conv_module(utterance, right_context)
assert utterance.shape == (U, B, D), utterance.shape
assert right_context.shape == (R, B, D), right_context.shape
def test_convolution_module_infer():
from emformer import ConvolutionModule
B, D = 2, 256
chunk_length = 4
right_context_length = 2
num_chunks = 1
U = num_chunks * chunk_length
R = num_chunks * right_context_length
kernel_size = 31
conv_module = ConvolutionModule(
chunk_length,
right_context_length,
D,
kernel_size,
)
utterance = torch.randn(U, B, D)
right_context = torch.randn(R, B, D)
cache = torch.randn(B, D, kernel_size - 1)
utterance, right_context, new_cache = conv_module.infer(
utterance, right_context, cache
)
assert utterance.shape == (U, B, D), utterance.shape
assert right_context.shape == (R, B, D), right_context.shape
assert new_cache.shape == (B, D, kernel_size - 1), new_cache.shape
def test_state_stack_unstack():
num_features = 80
chunk_length = 32
encoder_dim = 512
num_encoder_layers = 2
kernel_size = 31
left_context_length = 32
right_context_length = 8
memory_size = 32
model = Emformer(
num_features=num_features,
chunk_length=chunk_length,
subsampling_factor=4,
d_model=encoder_dim,
num_encoder_layers=num_encoder_layers,
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
memory_size=memory_size,
)
for batch_size in [1, 2]:
attn_caches = [
[
torch.zeros(memory_size, batch_size, encoder_dim),
torch.zeros(left_context_length // 4, batch_size, encoder_dim),
torch.zeros(
left_context_length // 4,
batch_size,
encoder_dim,
),
]
for _ in range(num_encoder_layers)
]
conv_caches = [
torch.zeros(batch_size, encoder_dim, kernel_size - 1)
for _ in range(num_encoder_layers)
]
states = [attn_caches, conv_caches]
x = torch.randn(batch_size, 23, num_features)
x_lens = torch.full((batch_size,), 23)
num_processed_frames = torch.full((batch_size,), 0)
y, y_lens, states = model.infer(
x, x_lens, num_processed_frames=num_processed_frames, states=states
)
state_list = unstack_states(states)
states2 = stack_states(state_list)
for ss, ss2 in zip(states[0], states2[0]):
for s, s2 in zip(ss, ss2):
assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
for s, s2 in zip(states[1], states2[1]):
assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
def test_torchscript_consistency_infer():
r"""Verify that scripting Emformer does not change the behavior of method `infer`.""" # noqa
num_features = 80
chunk_length = 32
encoder_dim = 512
num_encoder_layers = 2
kernel_size = 31
left_context_length = 32
right_context_length = 8
memory_size = 32
batch_size = 2
model = Emformer(
num_features=num_features,
chunk_length=chunk_length,
subsampling_factor=4,
d_model=encoder_dim,
num_encoder_layers=num_encoder_layers,
cnn_module_kernel=kernel_size,
left_context_length=left_context_length,
right_context_length=right_context_length,
memory_size=memory_size,
).eval()
attn_caches = [
[
torch.zeros(memory_size, batch_size, encoder_dim),
torch.zeros(left_context_length // 4, batch_size, encoder_dim),
torch.zeros(
left_context_length // 4,
batch_size,
encoder_dim,
),
]
for _ in range(num_encoder_layers)
]
conv_caches = [
torch.zeros(batch_size, encoder_dim, kernel_size - 1)
for _ in range(num_encoder_layers)
]
states = [attn_caches, conv_caches]
x = torch.randn(batch_size, 23, num_features)
x_lens = torch.full((batch_size,), 23)
num_processed_frames = torch.full((batch_size,), 0)
y, y_lens, out_states = model.infer(
x, x_lens, num_processed_frames=num_processed_frames, states=states
)
sc_model = torch.jit.script(model).eval()
sc_y, sc_y_lens, sc_out_states = sc_model.infer(
x, x_lens, num_processed_frames=num_processed_frames, states=states
)
assert torch.allclose(y, sc_y)
if __name__ == "__main__":
test_convolution_module_forward()
test_convolution_module_infer()
test_state_stack_unstack()
test_torchscript_consistency_infer()

File diff suppressed because it is too large

View File

@ -77,9 +77,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" ==
fi fi
# Install quantization toolkit: # Install quantization toolkit:
# pip install git+https://github.com/danpovey/quantization.git@master # pip install git+https://github.com/k2-fsa/multi_quantization.git
# when testing this code: # or
# commit c17ffe67aa2e6ca6b6855c50fde812f2eed7870b is used. # pip install multi_quantization
has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)") has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
if [ $has_quantization == 'False' ]; then if [ $has_quantization == 'False' ]; then

View File

@ -19,11 +19,12 @@ from dataclasses import dataclass
from typing import Dict, List, Optional from typing import Dict, List, Optional
import k2 import k2
import sentencepiece as spm
import torch import torch
from model import Transducer from model import Transducer
from icefall.decode import Nbest, one_best_decoding from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts from icefall.utils import add_eos, add_sos, get_texts
def fast_beam_search_one_best( def fast_beam_search_one_best(
@ -34,6 +35,7 @@ def fast_beam_search_one_best(
beam: float, beam: float,
max_states: int, max_states: int,
max_contexts: int, max_contexts: int,
temperature: float = 1.0,
) -> List[List[int]]: ) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -44,7 +46,7 @@ def fast_beam_search_one_best(
model: model:
An instance of `Transducer`. An instance of `Transducer`.
decoding_graph: decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG. Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out: encoder_out:
A tensor of shape (N, T, C) from the encoder. A tensor of shape (N, T, C) from the encoder.
encoder_out_lens: encoder_out_lens:
@ -56,6 +58,8 @@ def fast_beam_search_one_best(
Max states per stream per frame. Max states per stream per frame.
max_contexts: max_contexts:
Max contexts pre stream per frame. Max contexts pre stream per frame.
temperature:
Softmax temperature.
Returns: Returns:
Return the decoded result. Return the decoded result.
""" """
@ -67,6 +71,7 @@ def fast_beam_search_one_best(
beam=beam, beam=beam,
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
temperature=temperature,
) )
best_path = one_best_decoding(lattice) best_path = one_best_decoding(lattice)
@ -85,6 +90,7 @@ def fast_beam_search_nbest_LG(
num_paths: int, num_paths: int,
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
use_double_scores: bool = True, use_double_scores: bool = True,
temperature: float = 1.0,
) -> List[List[int]]: ) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -100,7 +106,7 @@ def fast_beam_search_nbest_LG(
model: model:
An instance of `Transducer`. An instance of `Transducer`.
decoding_graph: decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG. Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out: encoder_out:
A tensor of shape (N, T, C) from the encoder. A tensor of shape (N, T, C) from the encoder.
encoder_out_lens: encoder_out_lens:
@ -120,6 +126,8 @@ def fast_beam_search_nbest_LG(
use_double_scores: use_double_scores:
True to use double precision for computation. False to use True to use double precision for computation. False to use
single precision. single precision.
temperature:
Softmax temperature.
Returns: Returns:
Return the decoded result. Return the decoded result.
""" """
@ -131,6 +139,7 @@ def fast_beam_search_nbest_LG(
beam=beam, beam=beam,
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
temperature=temperature,
) )
nbest = Nbest.from_lattice( nbest = Nbest.from_lattice(
@ -201,6 +210,7 @@ def fast_beam_search_nbest(
num_paths: int, num_paths: int,
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
use_double_scores: bool = True, use_double_scores: bool = True,
temperature: float = 1.0,
) -> List[List[int]]: ) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -216,7 +226,7 @@ def fast_beam_search_nbest(
model: model:
An instance of `Transducer`. An instance of `Transducer`.
decoding_graph: decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG. Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out: encoder_out:
A tensor of shape (N, T, C) from the encoder. A tensor of shape (N, T, C) from the encoder.
encoder_out_lens: encoder_out_lens:
@ -236,6 +246,8 @@ def fast_beam_search_nbest(
use_double_scores: use_double_scores:
True to use double precision for computation. False to use True to use double precision for computation. False to use
single precision. single precision.
temperature:
Softmax temperature.
Returns: Returns:
Return the decoded result. Return the decoded result.
""" """
@ -247,6 +259,7 @@ def fast_beam_search_nbest(
beam=beam, beam=beam,
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
temperature=temperature,
) )
nbest = Nbest.from_lattice( nbest = Nbest.from_lattice(
@ -282,6 +295,7 @@ def fast_beam_search_nbest_oracle(
ref_texts: List[List[int]], ref_texts: List[List[int]],
use_double_scores: bool = True, use_double_scores: bool = True,
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> List[List[int]]: ) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -297,7 +311,7 @@ def fast_beam_search_nbest_oracle(
model: model:
An instance of `Transducer`. An instance of `Transducer`.
decoding_graph: decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG. Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out: encoder_out:
A tensor of shape (N, T, C) from the encoder. A tensor of shape (N, T, C) from the encoder.
encoder_out_lens: encoder_out_lens:
@ -321,7 +335,8 @@ def fast_beam_search_nbest_oracle(
nbest_scale: nbest_scale:
It's the scale applied to the lattice.scores. A smaller value It's the scale applied to the lattice.scores. A smaller value
yields more unique paths. yields more unique paths.
temperature:
Softmax temperature.
Returns: Returns:
Return the decoded result. Return the decoded result.
""" """
@ -333,6 +348,7 @@ def fast_beam_search_nbest_oracle(
beam=beam, beam=beam,
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
temperature=temperature,
) )
nbest = Nbest.from_lattice( nbest = Nbest.from_lattice(
@ -373,6 +389,7 @@ def fast_beam_search(
beam: float, beam: float,
max_states: int, max_states: int,
max_contexts: int, max_contexts: int,
temperature: float = 1.0,
) -> k2.Fsa: ) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -380,7 +397,7 @@ def fast_beam_search(
model: model:
An instance of `Transducer`. An instance of `Transducer`.
decoding_graph: decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG. Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out: encoder_out:
A tensor of shape (N, T, C) from the encoder. A tensor of shape (N, T, C) from the encoder.
encoder_out_lens: encoder_out_lens:
@ -392,6 +409,8 @@ def fast_beam_search(
Max states per stream per frame. Max states per stream per frame.
max_contexts: max_contexts:
Max contexts pre stream per frame. Max contexts pre stream per frame.
temperature:
Softmax temperature.
Returns: Returns:
Return an FsaVec with axes [utt][state][arc] containing the decoded Return an FsaVec with axes [utt][state][arc] containing the decoded
lattice. Note: When the input graph is a TrivialGraph, the returned lattice. Note: When the input graph is a TrivialGraph, the returned
@ -440,7 +459,7 @@ def fast_beam_search(
project_input=False, project_input=False,
) )
logits = logits.squeeze(1).squeeze(1) logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1) log_probs = (logits / temperature).log_softmax(dim=-1)
decoding_streams.advance(log_probs) decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams() decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist()) lattice = decoding_streams.format_output(encoder_out_lens.tolist())
@ -783,6 +802,7 @@ def modified_beam_search(
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor, encoder_out_lens: torch.Tensor,
beam: int = 4, beam: int = 4,
temperature: float = 1.0,
) -> List[List[int]]: ) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@ -796,6 +816,8 @@ def modified_beam_search(
encoder_out before padding. encoder_out before padding.
beam: beam:
Number of active paths during the beam search. Number of active paths during the beam search.
temperature:
Softmax temperature.
Returns: Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance. for the i-th utterance.
@ -879,7 +901,9 @@ def modified_beam_search(
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs = (logits / temperature).log_softmax(
dim=-1
) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs) log_probs.add_(ys_log_probs)
@ -1043,6 +1067,7 @@ def beam_search(
model: Transducer, model: Transducer,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
beam: int = 4, beam: int = 4,
temperature: float = 1.0,
) -> List[int]: ) -> List[int]:
""" """
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@ -1056,6 +1081,8 @@ def beam_search(
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam: beam:
Beam size. Beam size.
temperature:
Softmax temperature.
Returns: Returns:
Return the decoded result. Return the decoded result.
""" """
@ -1132,7 +1159,7 @@ def beam_search(
) )
# TODO(fangjun): Scale the blank posterior # TODO(fangjun): Scale the blank posterior
log_prob = logits.log_softmax(dim=-1) log_prob = (logits / temperature).log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size) # log_prob is (1, 1, 1, vocab_size)
log_prob = log_prob.squeeze() log_prob = log_prob.squeeze()
# Now log_prob is (vocab_size,) # Now log_prob is (vocab_size,)
@ -1171,3 +1198,344 @@ def beam_search(
best_hyp = B.get_most_probable(length_norm=True) best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys return ys
def fast_beam_search_with_nbest_rescoring(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
ngram_lm_scale_list: List[float],
num_paths: int,
G: k2.Fsa,
sp: spm.SentencePieceProcessor,
word_table: k2.SymbolTable,
oov_word: str = "<UNK>",
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> Dict[str, List[List[int]]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, num_paths paths are selected
and rescored using a given language model. The shortest path within the
lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
ngram_lm_scale_list:
A list of floats representing LM score scales.
num_paths:
Number of paths to extract from the decoded lattice.
G:
An FsaVec containing only a single FSA. It is an n-gram LM.
sp:
The BPE model.
word_table:
The word symbol table.
oov_word:
OOV words are replaced with this word.
use_double_scores:
True to use double precision for computation. False to use
single precision.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
temperature:
Softmax temperature.
Returns:
Return the decoded result in a dict, where the key has the form
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
ngram LM scale value used during decoding, e.g., 0.1.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# at this point, nbest.fsa.scores are all zeros.
nbest = nbest.intersect(lattice)
# Now nbest.fsa.scores contains acoustic scores
am_scores = nbest.tot_scores()
# Now we need to compute the LM scores of each path.
# (1) Get the token IDs of each Path. We assume the decoding_graph
# is an acceptor, i.e., lattice is also an acceptor
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc]
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous())
tokens = tokens.remove_values_leq(0) # remove -1 and 0
token_list: List[List[int]] = tokens.tolist()
word_list: List[List[str]] = sp.decode(token_list)
assert isinstance(oov_word, str), oov_word
assert oov_word in word_table, oov_word
oov_word_id = word_table[oov_word]
word_ids_list: List[List[int]] = []
for words in word_list:
this_word_ids = []
for w in words.split():
if w in word_table:
this_word_ids.append(word_table[w])
else:
this_word_ids.append(oov_word_id)
word_ids_list.append(this_word_ids)
word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device)
word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas)
num_unique_paths = len(word_ids_list)
b_to_a_map = torch.zeros(
num_unique_paths,
dtype=torch.int32,
device=lattice.device,
)
rescored_word_fsas = k2.intersect_device(
a_fsas=G,
b_fsas=word_fsas_with_self_loops,
b_to_a_map=b_to_a_map,
sorted_match_a=True,
ret_arc_maps=False,
)
rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas)
rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))
ngram_lm_scores = rescored_word_fsas.get_tot_scores(
use_double_scores=True,
log_semiring=False,
)
ans: Dict[str, List[List[int]]] = {}
for s in ngram_lm_scale_list:
key = f"ngram_lm_scale_{s}"
tot_scores = am_scores.values + s * ngram_lm_scores
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
ans[key] = hyps
return ans
def fast_beam_search_with_nbest_rnn_rescoring(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
ngram_lm_scale_list: List[float],
num_paths: int,
G: k2.Fsa,
sp: spm.SentencePieceProcessor,
word_table: k2.SymbolTable,
rnn_lm_model: torch.nn.Module,
rnn_lm_scale_list: List[float],
oov_word: str = "<UNK>",
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> Dict[str, List[List[int]]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, num_paths paths are selected
and rescored using a given language model and an RNN LM.
The shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
ngram_lm_scale_list:
A list of floats representing LM score scales.
num_paths:
Number of paths to extract from the decoded lattice.
G:
An FsaVec containing only a single FSA. It is an n-gram LM.
sp:
The BPE model.
word_table:
The word symbol table.
rnn_lm_model:
An RNN LM model used for rescoring.
rnn_lm_scale_list:
A list of floats representing RNN score scales.
oov_word:
OOV words are replaced with this word.
use_double_scores:
True to use double precision for computation. False to use
single precision.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
temperature:
Softmax temperature.
Returns:
Return the decoded result in a dict, where the key has the form
'ngram_lm_scale_xx_rnn_lm_scale_yy' and the value is the decoded results.
`xx` and `yy` are the ngram LM and RNN LM scale values used during decoding, e.g., 0.1.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# at this point, nbest.fsa.scores are all zeros.
nbest = nbest.intersect(lattice)
# Now nbest.fsa.scores contains acoustic scores
am_scores = nbest.tot_scores()
# Now we need to compute the LM scores of each path.
# (1) Get the token IDs of each Path. We assume the decoding_graph
# is an acceptor, i.e., lattice is also an acceptor
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc]
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous())
tokens = tokens.remove_values_leq(0) # remove -1 and 0
token_list: List[List[int]] = tokens.tolist()
word_list: List[List[str]] = sp.decode(token_list)
assert isinstance(oov_word, str), oov_word
assert oov_word in word_table, oov_word
oov_word_id = word_table[oov_word]
word_ids_list: List[List[int]] = []
for words in word_list:
this_word_ids = []
for w in words.split():
if w in word_table:
this_word_ids.append(word_table[w])
else:
this_word_ids.append(oov_word_id)
word_ids_list.append(this_word_ids)
word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device)
word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas)
num_unique_paths = len(word_ids_list)
b_to_a_map = torch.zeros(
num_unique_paths,
dtype=torch.int32,
device=lattice.device,
)
rescored_word_fsas = k2.intersect_device(
a_fsas=G,
b_fsas=word_fsas_with_self_loops,
b_to_a_map=b_to_a_map,
sorted_match_a=True,
ret_arc_maps=False,
)
rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas)
rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))
ngram_lm_scores = rescored_word_fsas.get_tot_scores(
use_double_scores=True,
log_semiring=False,
)
# Now RNN-LM
blank_id = model.decoder.blank_id
sos_id = sp.piece_to_id("sos_id")
eos_id = sp.piece_to_id("eos_id")
sos_tokens = add_sos(tokens, sos_id)
tokens_eos = add_eos(tokens, eos_id)
sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
x_tokens = x_tokens.to(torch.int64)
y_tokens = y_tokens.to(torch.int64)
sentence_lengths = sentence_lengths.to(torch.int64)
rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths)
assert rnn_lm_nll.ndim == 2
assert rnn_lm_nll.shape[0] == len(token_list)
rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1)
ans: Dict[str, List[List[int]]] = {}
for n_scale in ngram_lm_scale_list:
for rnn_scale in rnn_lm_scale_list:
key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}"
tot_scores = (
am_scores.values
+ n_scale * ngram_lm_scores
+ rnn_scale * rnn_lm_scores
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
ans[key] = hyps
return ans
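
The recurring temperature argument introduced throughout this file simply rescales the logits before the log-softmax, so temperature=1.0 reproduces the previous behaviour, values below 1 sharpen the token distribution, and values above 1 flatten it. A tiny illustrative sketch (the logits here are made up):

import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # toy scores for a 4-token vocabulary

for temperature in (0.5, 1.0, 2.0):
    # Same transform as (logits / temperature).log_softmax(dim=-1) above.
    probs = (logits / temperature).log_softmax(dim=-1).exp()
    print(temperature, probs)
# temperature < 1 concentrates probability on the best-scoring token,
# temperature > 1 spreads it more evenly across the vocabulary.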

View File

@ -111,6 +111,8 @@ from beam_search import (
fast_beam_search_nbest_LG, fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle, fast_beam_search_nbest_oracle,
fast_beam_search_one_best, fast_beam_search_one_best,
fast_beam_search_with_nbest_rescoring,
fast_beam_search_with_nbest_rnn_rescoring,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
@ -124,8 +126,10 @@ from icefall.checkpoint import (
load_checkpoint, load_checkpoint,
) )
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
load_averaged_model,
setup_logger, setup_logger,
store_transcripts, store_transcripts,
str2bool, str2bool,
@ -312,6 +316,91 @@ def get_parser():
help="left context can be seen during decoding (in frames after subsampling)", help="left context can be seen during decoding (in frames after subsampling)",
) )
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="""Softmax temperature.
The output of the model is (logits / temperature).log_softmax().
""",
)
parser.add_argument(
"--lm-dir",
type=Path,
default=Path("./data/lm"),
help="""Used only when --decoding-method is
fast_beam_search_with_nbest_rescoring.
It should contain either G_4_gram.pt or G_4_gram.fst.txt
""",
)
parser.add_argument(
"--words-txt",
type=Path,
default=Path("./data/lang_bpe_500/words.txt"),
help="""Used only when --decoding-method is
fast_beam_search_with_nbest_rescoring.
It is the word table.
""",
)
parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is rnn-lm.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is rnn-lm.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is rnn-lm.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=True,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -324,6 +413,8 @@ def decode_one_batch(
batch: dict, batch: dict,
word_table: Optional[k2.SymbolTable] = None, word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
G: Optional[k2.Fsa] = None,
rnn_lm_model: torch.nn.Module = None,
) -> Dict[str, List[List[str]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
following format: following format:
@ -352,6 +443,11 @@ def decode_one_batch(
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
G:
Optional. Used only when decoding method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_oracle,
or fast_beam_search_with_nbest_rescoring.
It is an FsaVec containing an acceptor.
Returns: Returns:
Return the decoding result. See above description for the format of Return the decoding result. See above description for the format of
the returned dict. the returned dict.
@ -397,6 +493,7 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
temperature=params.temperature,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -411,6 +508,7 @@ def decode_one_batch(
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
temperature=params.temperature,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
@ -425,6 +523,7 @@ def decode_one_batch(
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
temperature=params.temperature,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -440,6 +539,7 @@ def decode_one_batch(
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
temperature=params.temperature,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -460,9 +560,56 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
temperature=params.temperature,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0]
ngram_lm_scale_list += [0.01, 0.02, 0.05]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8]
ngram_lm_scale_list += [1.0, 1.5, 2.5, 3]
hyp_tokens = fast_beam_search_with_nbest_rescoring(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
ngram_lm_scale_list=ngram_lm_scale_list,
num_paths=params.num_paths,
G=G,
sp=sp,
word_table=word_table,
use_double_scores=True,
nbest_scale=params.nbest_scale,
temperature=params.temperature,
)
elif params.decoding_method == "fast_beam_search_with_nbest_rnn_rescoring":
ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0]
ngram_lm_scale_list += [0.01, 0.02, 0.05]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8]
ngram_lm_scale_list += [1.0, 1.5, 2.5, 3]
hyp_tokens = fast_beam_search_with_nbest_rnn_rescoring(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
ngram_lm_scale_list=ngram_lm_scale_list,
num_paths=params.num_paths,
G=G,
sp=sp,
word_table=word_table,
rnn_lm_model=rnn_lm_model,
rnn_lm_scale_list=ngram_lm_scale_list,
use_double_scores=True,
nbest_scale=params.nbest_scale,
temperature=params.temperature,
)
else:
batch_size = encoder_out.size(0)
@ -496,6 +643,7 @@ def decode_one_batch(
f"beam_{params.beam}_" f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_" f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}" f"max_states_{params.max_states}"
f"temperature_{params.temperature}"
): hyps ): hyps
} }
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
@ -504,8 +652,26 @@ def decode_one_batch(
f"beam_{params.beam}_" f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_" f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}" f"max_states_{params.max_states}"
f"temperature_{params.temperature}"
): hyps ): hyps
} }
elif params.decoding_method in [
"fast_beam_search_with_nbest_rescoring",
"fast_beam_search_with_nbest_rnn_rescoring",
]:
prefix = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}_"
f"num_paths_{params.num_paths}_"
f"nbest_scale_{params.nbest_scale}_"
f"temperature_{params.temperature}_"
)
ans: Dict[str, List[List[str]]] = {}
for key, hyp in hyp_tokens.items():
t: List[str] = sp.decode(hyp)
ans[prefix + key] = [s.split() for s in t]
return ans
elif "fast_beam_search" in params.decoding_method: elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_" key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_" key += f"max_contexts_{params.max_contexts}_"
@ -515,10 +681,14 @@ def decode_one_batch(
key += f"nbest_scale_{params.nbest_scale}" key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method: if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}" key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps} return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps} return {
(
f"beam_size_{params.beam_size}_"
f"temperature_{params.temperature}"
): hyps
}
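As an illustration (values are placeholders, not from this diff), the branch above now produces result-dict keys such as:

beam_size, temperature = 4, 1.0  # hypothetical values of --beam-size and --temperature
key = f"beam_size_{beam_size}_temperature_{temperature}"
print(key)  # beam_size_4_temperature_1.0
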
def decode_dataset(
@ -528,6 +698,8 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
G: Optional[k2.Fsa] = None,
rnn_lm_model: torch.nn.Module = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -546,6 +718,11 @@ def decode_dataset(
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
G:
Optional. Used only when the decoding method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_oracle,
or fast_beam_search_with_nbest_rescoring.
It's an FsaVec containing an acceptor.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -576,6 +753,8 @@ def decode_dataset(
word_table=word_table,
decoding_graph=decoding_graph,
batch=batch,
G=G,
rnn_lm_model=rnn_lm_model,
)
for name, hyps in hyps_dict.items():
@ -642,6 +821,71 @@ def save_results(
logging.info(s)
def load_ngram_LM(
lm_dir: Path, word_table: k2.SymbolTable, device: torch.device
) -> k2.Fsa:
"""Read a ngram model from the given directory.
Args:
lm_dir:
It should contain either G_4_gram.pt or G_4_gram.fst.txt
word_table:
The word table mapping words to IDs and vice versa.
device:
The resulting FSA will be moved to this device.
Returns:
Return an FsaVec containing a single acceptor.
"""
lm_dir = Path(lm_dir)
assert lm_dir.is_dir(), f"{lm_dir} does not exist"
pt_file = lm_dir / "G_4_gram.pt"
if pt_file.is_file():
logging.info(f"Loading pre-compiled {pt_file}")
d = torch.load(pt_file, map_location=device)
G = k2.Fsa.from_dict(d)
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
return G
txt_file = lm_dir / "G_4_gram.fst.txt"
assert txt_file.is_file(), f"{txt_file} does not exist"
logging.info(f"Loading {txt_file}")
logging.warning("It may take about 8 minutes (the result will be cached for later use).")
with open(txt_file) as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# Now G is an acceptor
first_word_disambig_id = word_table["#0"]
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set G.properties to None
G.__dict__["_properties"] = None
G = k2.Fsa.from_fsas([G]).to(device)
# Save a dummy value so that it can be loaded in C++.
# See https://github.com/pytorch/pytorch/issues/67902
# for why we need to do this.
G.dummy = 1
logging.info(f"Saving to {pt_file} for later use")
torch.save(G.as_dict(), pt_file)
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
return G
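A hypothetical usage sketch of the load_ngram_LM helper added above; the paths are placeholders that mirror the defaults of --lm-dir and --words-txt:

import k2
import torch

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
word_table = k2.SymbolTable.from_file("data/lang_bpe_500/words.txt")
G = load_ngram_LM(lm_dir="data/lm", word_table=word_table, device=device)
# G is an FsaVec holding a single arc-sorted acceptor with epsilon self-loops.
print(G.shape)
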
@torch.no_grad()
def main():
parser = get_parser()
@ -660,6 +904,8 @@ def main():
"fast_beam_search_nbest_LG", "fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
"modified_beam_search", "modified_beam_search",
"fast_beam_search_with_nbest_rescoring",
"fast_beam_search_with_nbest_rnn_rescoring",
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
@ -676,6 +922,7 @@ def main():
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"-temperature-{params.temperature}"
if "nbest" in params.decoding_method: if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-num-paths-{params.num_paths}"
@ -685,9 +932,11 @@ def main():
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
params.suffix += f"-temperature-{params.temperature}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-temperature-{params.temperature}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -760,14 +1009,59 @@ def main():
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
elif params.decoding_method in [
"fast_beam_search_with_nbest_rescoring",
"fast_beam_search_with_nbest_rnn_rescoring",
]:
logging.info(f"Loading word symbol table from {params.words_txt}")
word_table = k2.SymbolTable.from_file(params.words_txt)
G = load_ngram_LM(
lm_dir=params.lm_dir,
word_table=word_table,
device=device,
)
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
logging.info(f"G properties_str: {G.properties_str}")
rnn_lm_model = None
if (
params.decoding_method
== "fast_beam_search_with_nbest_rnn_rescoring"
):
rnn_lm_model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
if params.rnn_lm_avg == 1:
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
else:
rnn_lm_model = load_averaged_model(
params.rnn_lm_exp_dir,
rnn_lm_model,
params.rnn_lm_epoch,
params.rnn_lm_avg,
device,
)
rnn_lm_model.eval()
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
rnn_lm_model = None
else:
decoding_graph = None
word_table = None
rnn_lm_model = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -792,6 +1086,8 @@ def main():
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
G=G,
rnn_lm_model=rnn_lm_model,
)
save_results(

View File

@ -1601,10 +1601,6 @@ class RandomCombine(nn.Module):
is a random combination of all the inputs; but which in test time
will be just the last input.
All but the last input will have a linear transform before we
randomly combine them; these linear transforms will be initialized
to the identity transform.
The idea is that the list of Tensors will be a list of outputs of multiple
conformer layers. This has a similar effect to iterated loss. (See:
DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
@ -1804,7 +1800,6 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
num_channels = 50
m = RandomCombine(
num_inputs=num_inputs,
num_channels=num_channels,
final_weight=final_weight,
pure_prob=pure_prob,
stddev=stddev,
@ -1826,9 +1821,7 @@ def _test_random_combine_main():
_test_random_combine(0.5, 0.5, 0.3)
feature_dim = 50
c = Conformer( c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
num_features=feature_dim, output_dim=256, d_model=128, nhead=4
)
batch_size = 5
seq_len = 20
# Just make sure the forward pass runs.

View File

@ -23,7 +23,7 @@ from scaling import ScaledLinear
from icefall.utils import add_sos
from quantization.prediction import JointCodebookLoss from multi_quantization.prediction import JointCodebookLoss
class Transducer(nn.Module):
@ -75,7 +75,9 @@ class Transducer(nn.Module):
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if num_codebooks > 0:
self.codebook_loss_net = JointCodebookLoss(
predictor_channels=encoder_dim, num_codebooks=num_codebooks predictor_channels=encoder_dim,
num_codebooks=num_codebooks,
is_joint=False,
)
def forward(

View File

@ -879,6 +879,11 @@ def run(rank, world_size, args):
The return value of get_parser().parse_args()
"""
params = get_params()
# Note: it's better to set --spec-aug-time-warp-factor=-1
# when doing distillation with vq.
assert args.spec_aug_time_warp_factor < 1
params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 1600

View File

@ -43,7 +43,7 @@ def compute_fbank_wenetspeech_dev_test():
# number of seconds in a batch
batch_duration = 600
subsets = ("S", "M", "DEV", "TEST_NET", "TEST_MEETING") subsets = ("DEV", "TEST_NET", "TEST_MEETING")
device = torch.device("cpu")
if torch.cuda.is_available():
@ -63,8 +63,12 @@ def compute_fbank_wenetspeech_dev_test():
logging.info(f"Loading {raw_cuts_path}") logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path) cut_set = CutSet.from_file(raw_cuts_path)
logging.info("Computing features") logging.info("Splitting cuts into smaller chunks")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info("Computing features")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{in_out_dir}/feats_{partition}",
@ -72,9 +76,6 @@ def compute_fbank_wenetspeech_dev_test():
batch_duration=batch_duration,
storage_type=LilcomHdf5Writer,
)
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info(f"Saving to {cuts_path}") logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path) cut_set.to_file(cuts_path)
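A minimal sketch of the reordered pipeline above (trim to supervisions before feature extraction), using lhotse; the manifest paths and the fbank configuration are assumptions, not part of this diff:

import torch
from lhotse import CutSet, LilcomHdf5Writer
from lhotse.features.kaldifeat import KaldifeatFbank, KaldifeatFbankConfig

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))

cut_set = CutSet.from_file("data/fbank/cuts_DEV_raw.jsonl.gz")  # assumed path
# Split long recordings into per-supervision cuts first ...
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False, min_duration=None)
# ... then compute and store the features for the trimmed cuts.
cut_set = cut_set.compute_and_store_features_batch(
    extractor=extractor,
    storage_path="data/fbank/feats_DEV",
    batch_duration=600,
    storage_type=LilcomHdf5Writer,
)
cut_set.to_file("data/fbank/cuts_DEV.jsonl.gz")
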

View File

@ -128,8 +128,12 @@ def compute_fbank_wenetspeech_splits(args):
logging.info(f"Loading {raw_cuts_path}") logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path) cut_set = CutSet.from_file(raw_cuts_path)
logging.info("Computing features") logging.info("Splitting cuts into smaller chunks.")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info("Computing features")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/feats_{subset}_{idx}",
@ -138,14 +142,8 @@ def compute_fbank_wenetspeech_splits(args):
storage_type=LilcomChunkyWriter,
)
logging.info("About to split cuts into smaller chunks.")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info(f"Saving to {cuts_path}") logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path) cut_set.to_file(cuts_path)
logging.info(f"Saved to {cuts_path}")
def main():

View File

@ -75,6 +75,16 @@ def main():
logging.info("Starting writing the words.txt") logging.info("Starting writing the words.txt")
f_out = open(output_file, "w", encoding="utf-8") f_out = open(output_file, "w", encoding="utf-8")
# LG decoding needs the symbols below.
id1, id2, id3 = (
str(len(new_lines)),
str(len(new_lines) + 1),
str(len(new_lines) + 2),
)
add_words = ["#0 " + id1, "<s> " + id2, "</s> " + id3]
new_lines.extend(add_words)
for line in new_lines:
f_out.write(line)
f_out.write("\n")

View File

@ -1006,6 +1006,8 @@ def rescore_with_rnn_lm(
An FsaVec with axes [utt][state][arc].
num_paths:
Number of paths to extract from the given lattice for rescoring.
rnn_lm_model:
An RNN LM used for LM rescoring.
model:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.

View File

@ -29,20 +29,10 @@ import torch
def get_git_sha1():
    try:
        git_commit = (
            subprocess.run(
                ["git", "rev-parse", "--short", "HEAD"],
                check=True,
                stdout=subprocess.PIPE,
            )
            .stdout.decode()
            .rstrip("\n")
            .strip()
        )
        dirty_commit = (
            len(
                subprocess.run(
                    ["git", "diff", "--shortstat"],
                    check=True,
                    stdout=subprocess.PIPE,
                )
                .stdout.decode()
                .rstrip("\n")
                .strip()
            )
            > 0
        )
        git_commit = (
            git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
        )
    except:  # noqa
        return None

    return git_commit


def get_git_date():
    try:
        git_date = (
            subprocess.run(
                ["git", "log", "-1", "--format=%ad", "--date=local"],
                check=True,
                stdout=subprocess.PIPE,
            )
            .stdout.decode()
            .rstrip("\n")
            .strip()
        )
    except:  # noqa
        return None

    return git_date


def get_git_branch_name():
    try:
        git_date = (
            subprocess.run(
                ["git", "rev-parse", "--abbrev-ref", "HEAD"],
                check=True,
                stdout=subprocess.PIPE,
            )
            .stdout.decode()
            .rstrip("\n")
            .strip()
        )
    except:  # noqa
        return None

    return git_date
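A brief usage sketch of the guarded helpers above: with the try/except in place they return None instead of raising when the code is not run inside a git checkout (behaviour inferred from the code, not stated elsewhere in this diff).

sha1 = get_git_sha1()
if sha1 is None:
    print("git metadata unavailable (not a git checkout, or git is missing)")
else:
    print(f"icefall sha1: {sha1}, branch: {get_git_branch_name()}, date: {get_git_date()}")
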

View File

@ -96,8 +96,6 @@ def str2bool(v):
def setup_logger(
log_filename: Pathlike,
log_level: str = "info",
rank: int = 0,
world_size: int = 1,
use_console: bool = True,
) -> None:
"""Setup log level.
@ -108,16 +106,14 @@ def setup_logger(
log_level:
The log level to use, e.g., "debug", "info", "warning", "error",
"critical"
rank:
Rank of this node in DDP training.
world_size:
Number of nodes in DDP training.
use_console:
True to also print logs to console.
"""
now = datetime.now()
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
if world_size > 1: if dist.is_available() and dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
log_filename = f"{log_filename}-{date_time}-{rank}"
else:
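A minimal sketch of the pattern adopted above: the DDP rank and world size are queried from torch.distributed when a process group is initialised, instead of being passed in as arguments (standalone example, not part of icefall):

import torch.distributed as dist

if dist.is_available() and dist.is_initialized():
    rank, world_size = dist.get_rank(), dist.get_world_size()
else:
    rank, world_size = 0, 1
log_suffix = f"-{rank}" if world_size > 1 else ""
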

View File

@ -19,3 +19,4 @@ kaldialign==0.2
sentencepiece==0.1.96
tensorboard==2.8.0
typeguard==2.13.3
multi_quantization

View File

@ -3,3 +3,4 @@ kaldialign
sentencepiece>=0.1.96
tensorboard
typeguard
multi_quantization