[recipe] AMI Zipformer transducer (#698)

* remove unnecessary changes

* add AMI prepare scripts

* add zipformer scripts for AMI

* added logs and pretrained model

* minor fix

* remove unwanted changes

* fix missing link

* make suggested changes

* update results
This commit is contained in:
Desh Raj 2022-11-25 21:00:45 -05:00 committed by GitHub
parent 4c636c2cff
commit db75627e92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 3222 additions and 0 deletions

48
egs/ami/ASR/README.md Normal file
View File

@ -0,0 +1,48 @@
# AMI
This is an ASR recipe for the AMI corpus. AMI provides recordings from the speaker's
headset and lapel microphones, and also 2 array microphones containing 8 channels each.
We pool data in the following 4 ways and train a single model on the pooled data:
(i) individual headset microphone (IHM)
(ii) IHM with simulated reverb
(iii) Single distant microphone (SDM)
(iv) GSS-enhanced array microphones
Speed perturbation and MUSAN noise augmentation are additionally performed on the pooled
data. Here are the statistics of the combined training data:
```python
>>> cuts_train.describe()
Cuts count: 1222053
Total duration (hh:mm:ss): 905:00:28
Speech duration (hh:mm:ss): 905:00:28 (99.9%)
Duration statistics (seconds):
mean 2.7
std 2.8
min 0.0
25% 0.6
50% 1.6
75% 3.8
99% 12.3
99.5% 13.9
99.9% 18.4
max 36.8
```
**Note:** This recipe additionally uses [GSS](https://github.com/desh2608/gss) for enhancement
of far-field array microphones, but this is optional (see `prepare.sh` for details).
## Performance Record
### pruned_transducer_stateless7
The following are decoded using `modified_beam_search`:
| Evaluation set | dev WER | test WER |
|--------------------------|------------|---------|
| IHM | 18.92 | 17.40 |
| SDM | 31.25 | 32.21 |
| MDM (GSS-enhanced) | 21.67 | 22.43 |
See [RESULTS](/egs/ami/ASR/RESULTS.md) for details.

92
egs/ami/ASR/RESULTS.md Normal file
View File

@ -0,0 +1,92 @@
## Results
### AMI training results (Pruned Transducer)
#### 2022-11-20
#### Zipformer (pruned_transducer_stateless7)
Zipformer encoder + non-current decoder. The decoder
contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
layer (to transform tensor dim).
All the results below are using a single model that is trained by combining the following
data: IHM, IHM+reverb, SDM, and GSS-enhanced MDM. Speed perturbation and MUSAN noise
augmentation are applied on top of the pooled data.
**WERs for IHM:**
| | dev | test | comment |
|---------------------------|------------|------------|------------------------------------------|
| greedy search | 19.25 | 17.83 | --epoch 14 --avg 8 --max-duration 500 |
| modified beam search | 18.92 | 17.40 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 |
| fast beam search | 19.44 | 18.04 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
**WERs for SDM:**
| | dev | test | comment |
|---------------------------|------------|------------|------------------------------------------|
| greedy search | 31.32 | 32.38 | --epoch 14 --avg 8 --max-duration 500 |
| modified beam search | 31.25 | 32.21 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 |
| fast beam search | 31.11 | 32.10 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
**WERs for GSS-enhanced MDM:**
| | dev | test | comment |
|---------------------------|------------|------------|------------------------------------------|
| greedy search | 22.05 | 22.93 | --epoch 14 --avg 8 --max-duration 500 |
| modified beam search | 21.67 | 22.43 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 |
| fast beam search | 22.21 | 22.83 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
The training command for reproducing is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless7/train.py \
--world-size 4 \
--num-epochs 15 \
--exp-dir pruned_transducer_stateless7/exp \
--max-duration 150 \
--max-cuts 150 \
--prune-range 5 \
--lr-factor 5 \
--lm-scale 0.25 \
--use-fp16 True
```
The decoding command is:
```
# greedy search
./pruned_transducer_stateless7/decode.py \
--epoch 14 \
--avg 8 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 500 \
--decoding-method greedy_search
# modified beam search
./pruned_transducer_stateless7/decode.py \
--iter 105000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 500 \
--decoding-method modified_beam_search \
--beam-size 4
# fast beam search
./pruned_transducer_stateless7/decode.py \
--iter 105000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
```
Pretrained model is available at <https://huggingface.co/desh2608/icefall-asr-ami-pruned-transducer-stateless7>
The tensorboard training log can be found at
<https://tensorboard.dev/experiment/VH10QOTBTbuYpWx994Onrg/#scalars>

View File

View File

@ -0,0 +1,194 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the AMI dataset.
For the training data, we pool together IHM, reverberated IHM, and GSS-enhanced
audios. For the test data, we separately prepare IHM, SDM, and GSS-enhanced
parts (which are the 3 evaluation settings).
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import math
from pathlib import Path
import torch
import torch.multiprocessing
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def compute_fbank_ami():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
logging.info("Reading manifests")
manifests_ihm = read_manifests_if_cached(
dataset_parts=["train", "dev", "test"],
output_dir=src_dir,
prefix="ami-ihm",
suffix="jsonl.gz",
)
manifests_sdm = read_manifests_if_cached(
dataset_parts=["train", "dev", "test"],
output_dir=src_dir,
prefix="ami-sdm",
suffix="jsonl.gz",
)
# For GSS we already have cuts so we read them directly.
manifests_gss = read_manifests_if_cached(
dataset_parts=["train", "dev", "test"],
output_dir=src_dir,
prefix="ami-gss",
suffix="jsonl.gz",
)
def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None:
cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
_ = cuts.compute_and_store_features_batch(
extractor=extractor,
storage_path=storage_path,
manifest_path=manifest_path,
batch_duration=5000,
num_workers=8,
storage_type=LilcomChunkyWriter,
)
logging.info(
"Preparing training cuts: IHM + reverberated IHM + SDM + GSS (optional)"
)
logging.info("Processing train split IHM")
cuts_ihm = (
CutSet.from_manifests(**manifests_ihm["train"])
.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
.modify_ids(lambda x: x + "-ihm")
)
_extract_feats(
cuts_ihm,
output_dir / "feats_train_ihm",
src_dir / "cuts_train_ihm.jsonl.gz",
)
logging.info("Processing train split IHM + reverberated IHM")
cuts_ihm_rvb = cuts_ihm.reverb_rir()
_extract_feats(
cuts_ihm_rvb,
output_dir / "feats_train_ihm_rvb",
src_dir / "cuts_train_ihm_rvb.jsonl.gz",
)
logging.info("Processing train split SDM")
cuts_sdm = (
CutSet.from_manifests(**manifests_sdm["train"])
.trim_to_supervisions(keep_overlapping=False)
.modify_ids(lambda x: x + "-sdm")
)
_extract_feats(
cuts_sdm,
output_dir / "feats_train_sdm",
src_dir / "cuts_train_sdm.jsonl.gz",
)
logging.info("Processing train split GSS")
cuts_gss = (
CutSet.from_manifests(**manifests_gss["train"])
.trim_to_supervisions(keep_overlapping=False)
.modify_ids(lambda x: x + "-gss")
)
_extract_feats(
cuts_gss,
output_dir / "feats_train_gss",
src_dir / "cuts_train_gss.jsonl.gz",
)
logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
for split in ["dev", "test"]:
logging.info(f"Processing {split} IHM")
cuts_ihm = (
CutSet.from_manifests(**manifests_ihm[split])
.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_{split}_ihm",
manifest_path=src_dir / f"cuts_{split}_ihm.jsonl.gz",
batch_duration=5000,
num_workers=8,
storage_type=LilcomChunkyWriter,
)
)
logging.info(f"Processing {split} SDM")
cuts_sdm = (
CutSet.from_manifests(**manifests_sdm[split])
.trim_to_supervisions(keep_overlapping=False)
.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_{split}_sdm",
manifest_path=src_dir / f"cuts_{split}_sdm.jsonl.gz",
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
)
logging.info(f"Processing {split} GSS")
cuts_gss = (
CutSet.from_manifests(**manifests_gss[split])
.trim_to_supervisions(keep_overlapping=False)
.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_{split}_gss",
manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz",
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_ami()

View File

@ -0,0 +1,114 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
from pathlib import Path
import torch
from lhotse import CutSet, LilcomChunkyWriter, combine
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_musan():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
dataset_parts = (
"music",
"speech",
"noise",
)
prefix = "musan"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
musan_cuts_path = src_dir / "musan_cuts.jsonl.gz"
if musan_cuts_path.is_file():
logging.info(f"{musan_cuts_path} already exists - skipping")
return
logging.info("Extracting features for Musan")
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
# create chunks of Musan with duration 5 - 10 seconds
_ = (
CutSet.from_manifests(
recordings=combine(part["recordings"] for part in manifests.values())
)
.cut_into_windows(10.0)
.filter(lambda c: c.duration > 5)
.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / "musan_feats",
manifest_path=musan_cuts_path,
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan()

View File

@ -0,0 +1,158 @@
#!/usr/local/bin/python
# -*- coding: utf-8 -*-
# Data preparation for AMI GSS-enhanced dataset.
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from lhotse import Recording, RecordingSet, SupervisionSet
from lhotse.qa import fix_manifests
from lhotse.recipes.utils import read_manifests_if_cached
from lhotse.utils import fastcopy
from tqdm import tqdm
logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)
def get_args():
import argparse
parser = argparse.ArgumentParser(description="AMI enhanced dataset preparation.")
parser.add_argument(
"manifests_dir",
type=Path,
help="Path to directory containing AMI manifests.",
)
parser.add_argument(
"enhanced_dir",
type=Path,
help="Path to enhanced data directory.",
)
parser.add_argument(
"--num-jobs",
"-j",
type=int,
default=1,
help="Number of parallel jobs to run.",
)
parser.add_argument(
"--min-segment-duration",
"-d",
type=float,
default=0.0,
help="Minimum duration of a segment in seconds.",
)
return parser.parse_args()
def find_recording_and_create_new_supervision(enhanced_dir, supervision):
"""
Given a supervision (corresponding to original AMI recording), this function finds the
enhanced recording correspoding to the supervision, and returns this recording and
a new supervision whose start and end times are adjusted to match the enhanced recording.
"""
file_name = Path(
f"{supervision.recording_id}-{supervision.speaker}-{int(100*supervision.start):06d}_{int(100*supervision.end):06d}.flac"
)
save_path = enhanced_dir / f"{supervision.recording_id}" / file_name
if save_path.exists():
recording = Recording.from_file(save_path)
if recording.duration == 0:
logging.warning(f"Skipping {save_path} which has duration 0 seconds.")
return None
# Old supervision is wrt to the original recording, we create new supervision
# wrt to the enhanced segment
new_supervision = fastcopy(
supervision,
recording_id=recording.id,
start=0,
duration=recording.duration,
)
return recording, new_supervision
else:
logging.warning(f"{save_path} does not exist.")
return None
def main(args):
# Get arguments
manifests_dir = args.manifests_dir
enhanced_dir = args.enhanced_dir
# Load manifests from cache if they exist (saves time)
manifests = read_manifests_if_cached(
dataset_parts=["train", "dev", "test"],
output_dir=manifests_dir,
prefix="ami-sdm",
suffix="jsonl.gz",
)
if not manifests:
raise ValueError("AMI SDM manifests not found in {}".format(manifests_dir))
with ThreadPoolExecutor(args.num_jobs) as ex:
for part in ["train", "dev", "test"]:
logging.info(f"Processing {part}...")
supervisions_orig = manifests[part]["supervisions"].filter(
lambda s: s.duration >= args.min_segment_duration
)
# Remove TS3009d supervisions since they are not present in the enhanced data
supervisions_orig = supervisions_orig.filter(
lambda s: s.recording_id != "TS3009d"
)
futures = []
for supervision in tqdm(
supervisions_orig,
desc="Distributing tasks",
):
futures.append(
ex.submit(
find_recording_and_create_new_supervision,
enhanced_dir,
supervision,
)
)
recordings = []
supervisions = []
for future in tqdm(
futures,
total=len(futures),
desc="Processing tasks",
):
result = future.result()
if result is not None:
recording, new_supervision = result
recordings.append(recording)
supervisions.append(new_supervision)
# Remove duplicates from the recordings
recordings_nodup = {}
for recording in recordings:
if recording.id not in recordings_nodup:
recordings_nodup[recording.id] = recording
else:
logging.warning("Recording {} is duplicated.".format(recording.id))
recordings = RecordingSet.from_recordings(recordings_nodup.values())
supervisions = SupervisionSet.from_segments(supervisions)
recordings, supervisions = fix_manifests(
recordings=recordings, supervisions=supervisions
)
logging.info(f"Writing {part} enhanced manifests")
recordings.to_file(manifests_dir / f"ami-gss_recordings_{part}.jsonl.gz")
supervisions.to_file(
manifests_dir / f"ami-gss_supervisions_{part}.jsonl.gz"
)
if __name__ == "__main__":
args = get_args()
main(args)

View File

@ -0,0 +1,98 @@
#!/bin/bash
# This script is used to run GSS-based enhancement on AMI data.
set -euo pipefail
nj=4
stage=0
. shared/parse_options.sh || exit 1
if [ $# != 2 ]; then
echo "Wrong #arguments ($#, expected 2)"
echo "Usage: local/prepare_ami_gss.sh [options] <data-dir> <exp-dir>"
echo "e.g. local/prepare_ami_gss.sh data/manifests exp/ami_gss"
echo "main options (for others, see top of script file)"
echo " --nj <nj> # number of parallel jobs"
echo " --stage <stage> # stage to start running from"
exit 1;
fi
DATA_DIR=$1
EXP_DIR=$2
mkdir -p $EXP_DIR
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 1 ]; then
log "Stage 1: Prepare cut sets"
for part in train dev test; do
lhotse cut simple \
-r $DATA_DIR/ami-mdm_recordings_${part}.jsonl.gz \
-s $DATA_DIR/ami-mdm_supervisions_${part}.jsonl.gz \
$EXP_DIR/cuts_${part}.jsonl.gz
done
fi
if [ $stage -le 2 ]; then
log "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)"
for part in train dev test; do
lhotse cut trim-to-supervisions --discard-overlapping \
$EXP_DIR/cuts_${part}.jsonl.gz $EXP_DIR/cuts_per_segment_${part}.jsonl.gz
done
fi
if [ $stage -le 3 ]; then
log "Stage 3: Split manifests for multi-GPU processing (optional)"
for part in train; do
gss utils split $nj $EXP_DIR/cuts_per_segment_${part}.jsonl.gz \
$EXP_DIR/cuts_per_segment_${part}_split$nj
done
fi
if [ $stage -le 4 ]; then
log "Stage 4: Enhance train segments using GSS (requires GPU)"
# for train, we use smaller context and larger batches to speed-up processing
for JOB in $(seq $nj); do
gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \
$EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.JOB.jsonl.gz $EXP_DIR/enhanced \
--bss-iterations 10 \
--context-duration 5.0 \
--use-garbage-class \
--channels 0,1,2,3,4,5,6,7 \
--min-segment-length 0.05 \
--max-segment-length 35.0 \
--max-batch-duration 60.0 \
--num-buckets 3 \
--num-workers 2
done
fi
if [ $stage -le 5 ]; then
log "Stage 5: Enhance dev/test segments using GSS (using GPU)"
# for dev/test, we use larger context and smaller batches to get better quality
for part in dev test; do
for JOB in $(seq $nj); do
gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \
$EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.JOB.jsonl.gz \
$EXP_DIR/enhanced \
--bss-iterations 10 \
--context-duration 15.0 \
--use-garbage-class \
--channels 0,1,2,3,4,5,6,7 \
--min-segment-length 0.05 \
--max-segment-length 30.0 \
--max-batch-duration 45.0 \
--num-buckets 3 \
--num-workers 2
done
done
fi
if [ $stage -le 6 ]; then
log "Stage 6: Prepare manifests for GSS-enhanced data"
python local/prepare_ami_enhanced.py $DATA_DIR $EXP_DIR/enhanced -j $nj --min-segment-duration 0.05
fi

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_bpe.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py

144
egs/ami/ASR/prepare.sh Executable file
View File

@ -0,0 +1,144 @@
#!/usr/bin/env bash
set -eou pipefail
stage=-1
stop_stage=100
use_gss=true # Use GSS-based enhancement with MDM setting
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/amicorpus
# You can find audio and transcripts in this path.
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
#
# - $dl_dir/{LDC2004S13,LDC2005S13,LDC2004T19,LDC2005T19}
# These contain the Fisher English audio and transcripts. We will
# only use the transcripts as extra LM training data (similar to Kaldi).
#
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
vocab_size=500
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/amicorpus,
# you can create a symlink
#
# ln -sfv /path/to/amicorpus $dl_dir/amicorpus
#
if [ ! -d $dl_dir/amicorpus ]; then
lhotse download ami --mic ihm $dl_dir/amicorpus
lhotse download ami --mic mdm $dl_dir/amicorpus
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare AMI manifests"
# We assume that you have downloaded the AMI corpus
# to $dl_dir/amicorpus. We perform text normalization for the transcripts.
mkdir -p data/manifests
for mic in ihm sdm mdm; do
lhotse prepare ami --mic $mic --partition full-corpus-asr --normalize-text kaldi \
--max-words-per-segment 30 $dl_dir/amicorpus data/manifests/
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to $dl_dir/musan
mkdir -p data/manifests
lhotse prepare musan $dl_dir/musan data/manifests
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ] && [ $use_gss = true ]; then
log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)"
# We assume that you have installed the GSS package: https://github.com/desh2608/gss
local/prepare_ami_gss.sh data/manifests exp/ami_gss
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank features for AMI"
mkdir -p data/fbank
python local/compute_fbank_ami.py
log "Combine features from train splits"
lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
gzip -c > data/manifests/cuts_train_all.jsonl.gz
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank features for musan"
mkdir -p data/fbank
python local/compute_fbank_musan.py
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Dump transcripts for BPE model training."
mkdir -p data/lm
cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g')> data/lm/transcript_words.txt
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Prepare BPE based lang"
lang_dir=data/lang_bpe_${vocab_size}
mkdir -p $lang_dir
# Add special words to words.txt
echo "<eps> 0" > $lang_dir/words.txt
echo "!SIL 1" >> $lang_dir/words.txt
echo "<UNK> 2" >> $lang_dir/words.txt
# Add regular words to words.txt
cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt
# Add remaining special word symbols expected by LM scripts.
num_words=$(cat $lang_dir/words.txt | wc -l)
echo "<s> ${num_words}" >> $lang_dir/words.txt
num_words=$(cat $lang_dir/words.txt | wc -l)
echo "</s> ${num_words}" >> $lang_dir/words.txt
num_words=$(cat $lang_dir/words.txt | wc -l)
echo "#0 ${num_words}" >> $lang_dir/words.txt
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript data/lm/transcript_words.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
fi
fi

View File

@ -0,0 +1,430 @@
# Copyright 2021 Piotr Żelasko
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import re
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.cut import Cut
from lhotse.dataset import (
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from tqdm import tqdm
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class AmiAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description=(
"These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc."
),
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/manifests"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help=(
"When enabled, select noise from MUSAN and mix it "
"with training dataset. "
),
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help=(
"When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding."
),
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help=(
"Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch."
),
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help=(
"The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used."
),
)
group.add_argument(
"--max-duration",
type=int,
default=100.0,
help=(
"Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM."
),
)
group.add_argument(
"--max-cuts", type=int, default=None, help="Maximum cuts in a single batch."
)
group.add_argument(
"--num-buckets",
type=int,
default=50,
help=(
"The number of buckets for the BucketingSampler"
"(you might want to increase it for larger datasets)."
),
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help=(
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available."
),
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help=(
"When enabled (=default), the examples will be "
"shuffled for each epoch."
),
)
group.add_argument(
"--num-workers",
type=int,
default=8,
help=(
"The number of training dataloader workers that " "collect the batches."
),
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help=(
"Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp."
),
)
group.add_argument(
"--ihm-only",
type=str2bool,
default=False,
help="When enabled, only use IHM data for training.",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=2,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
if self.args.on_the_fly_feats:
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
)
else:
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
)
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
max_cuts=self.args.max_cuts,
shuffle=False,
num_buckets=self.args.num_buckets,
drop_last=True,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=True,
)
sampler = DynamicBucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
def remove_short_cuts(self, cut: Cut) -> bool:
"""
See: https://github.com/k2-fsa/icefall/issues/500
Basically, the zipformer model subsamples the input using the following formula:
num_out_frames = (num_in_frames - 7)//2
For num_out_frames to be at least 1, num_in_frames must be at least 9.
"""
return cut.duration >= 0.09
@lru_cache()
def train_cuts(self, sp: Optional[Any] = None) -> CutSet:
logging.info("About to get AMI train cuts")
def _remove_short_and_long_utt(c: Cut):
if c.duration < 0.2 or c.duration > 25.0:
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 7) // 2 + 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
return T >= len(tokens)
if self.args.ihm_only:
cuts_train = load_manifest_lazy(
self.args.manifest_dir / "cuts_train_ihm.jsonl.gz"
)
else:
cuts_train = load_manifest_lazy(
self.args.manifest_dir / "cuts_train_all.jsonl.gz"
)
return cuts_train.filter(_remove_short_and_long_utt)
@lru_cache()
def dev_ihm_cuts(self) -> CutSet:
logging.info("About to get AMI IHM dev cuts")
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_ihm.jsonl.gz")
return cs.filter(self.remove_short_cuts)
@lru_cache()
def dev_sdm_cuts(self) -> CutSet:
logging.info("About to get AMI SDM dev cuts")
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz")
return cs.filter(self.remove_short_cuts)
@lru_cache()
def dev_gss_cuts(self) -> CutSet:
if not (self.args.manifest_dir / "cuts_dev_gss.jsonl.gz").exists():
logging.info("No GSS dev cuts found")
return None
logging.info("About to get AMI GSS-enhanced dev cuts")
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_gss.jsonl.gz")
return cs.filter(self.remove_short_cuts)
@lru_cache()
def test_ihm_cuts(self) -> CutSet:
logging.info("About to get AMI IHM test cuts")
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_ihm.jsonl.gz")
return cs.filter(self.remove_short_cuts)
@lru_cache()
def test_sdm_cuts(self) -> CutSet:
logging.info("About to get AMI SDM test cuts")
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz")
return cs.filter(self.remove_short_cuts)
@lru_cache()
def test_gss_cuts(self) -> CutSet:
if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists():
logging.info("No GSS test cuts found")
return None
logging.info("About to get AMI GSS-enhanced test cuts")
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_gss.jsonl.gz")
return cs.filter(self.remove_short_cuts)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/beam_search.py

View File

@ -0,0 +1,747 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless7/decode.py \
--iter 105000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./pruned_transducer_stateless7/decode.py \
--iter 105000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 500 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless7/decode.py \
--iter 105000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 500 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./pruned_transducer_stateless7/decode.py \
--iter 105000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import AmiAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest_LG,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
word_table: Optional[k2.SymbolTable] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
word_table:
The word symbol table.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
word_table: Optional[k2.SymbolTable] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
test_set_cers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
wers_filename = (
params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(wers_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
# we also compute CER for AMI dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
cers_filename = (
params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(cers_filename, "w") as f:
cer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
)
test_set_cers[key] = cer
logging.info("Wrote detailed error stats to {}".format(wers_filename))
test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER\tCER", file=f)
for key in test_set_wers:
print(
"{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]),
file=f,
)
s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key in test_set_wers:
s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AmiAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest_LG",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(f"{params.lang_dir}/bpe.model")
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
ami = AmiAsrDataModule(args)
dev_ihm_cuts = ami.dev_ihm_cuts()
test_ihm_cuts = ami.test_ihm_cuts()
dev_sdm_cuts = ami.dev_sdm_cuts()
test_sdm_cuts = ami.test_sdm_cuts()
dev_gss_cuts = ami.dev_gss_cuts()
test_gss_cuts = ami.test_gss_cuts()
dev_ihm_dl = ami.test_dataloaders(dev_ihm_cuts)
test_ihm_dl = ami.test_dataloaders(test_ihm_cuts)
dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts)
test_sdm_dl = ami.test_dataloaders(test_sdm_cuts)
if dev_gss_cuts is not None:
dev_gss_dl = ami.test_dataloaders(dev_gss_cuts)
if test_gss_cuts is not None:
test_gss_dl = ami.test_dataloaders(test_gss_cuts)
test_sets = {
"dev_ihm": (dev_ihm_dl, dev_ihm_cuts),
"test_ihm": (test_ihm_dl, test_ihm_cuts),
"dev_sdm": (dev_sdm_dl, dev_sdm_cuts),
"test_sdm": (test_sdm_dl, test_sdm_cuts),
}
if dev_gss_cuts is not None:
test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts)
if test_gss_cuts is not None:
test_sets["test_gss"] = (test_gss_dl, test_gss_cuts)
for test_set in test_sets:
logging.info(f"Decoding {test_set}")
dl, cuts = test_sets[test_set]
results_dict = decode_dataset(
dl=dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/export.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/optim.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py

1
egs/ami/ASR/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared